
Commit

fix bug in hiyouga#50
hiyouga committed May 5, 2023
1 parent 268ef28 commit ff4d089
Showing 4 changed files with 16 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/train_ppo.py
@@ -18,8 +18,8 @@
preprocess_data,
PPODataCollatorForChatGLM,
PPOTrainerForChatGLM,
- InvalidScoreLogitsProcessor,
compute_rewards,
+ get_logits_processor,
plot_loss
)

@@ -70,7 +70,7 @@ def main():
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"logits_processor": InvalidScoreLogitsProcessor()
"logits_processor": get_logits_processor()
}
output_length_sampler = LengthSampler(data_args.max_target_length // 2, data_args.max_target_length)

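The change above replaces the bare InvalidScoreLogitsProcessor instance with get_logits_processor(), which returns a LogitsProcessorList. transformers' generate() merges a user-supplied logits_processor with its default processors and expects a list-like LogitsProcessorList there, so passing a single LogitsProcessor instance can fail at generation time; this appears to be the bug reported in hiyouga#50. Below is a minimal sketch of the new usage, not code from this repository: the GPT-2 model and prompt are placeholders, and the import assumes the repository's src/ directory is on the Python path.

# Sketch only: placeholder model/prompt, assumed import path.
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_logits_processor  # helper added in this commit

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
generation_kwargs = {
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,  # GPT-2 has no pad token
    "eos_token_id": tokenizer.eos_token_id,
    "logits_processor": get_logits_processor(),  # a LogitsProcessorList
}
output = model.generate(**inputs, **generation_kwargs)
print(tokenizer.decode(output[0], skip_special_tokens=True))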
2 changes: 1 addition & 1 deletion src/utils/__init__.py
@@ -24,4 +24,4 @@

from .config import ModelArguments

- from .other import InvalidScoreLogitsProcessor, plot_loss
+ from .other import get_logits_processor, plot_loss
3 changes: 3 additions & 0 deletions src/utils/common.py
@@ -154,6 +154,9 @@ def load_pretrained(
logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
finetuning_args = torch.load(os.path.join(model_args.checkpoint_dir[0], FINETUNING_ARGS_NAME))

+ if stage != "sft" and finetuning_args.finetuning_type != "lora":
+     raise ValueError("RM and PPO training can only be performed with LoRA method.")

quantization = None
if model_args.quantization_bit is not None:
if is_trainable:
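The new check in load_pretrained rejects checkpoints whose finetuning_type is not "lora" for any stage other than SFT, presumably because the RM and PPO pipelines in this repository are built around LoRA adapters. A distilled sketch of the guard's behavior follows; FinetuningArguments and check_stage are hypothetical stand-ins, not the repository's real load_pretrained signature.

from dataclasses import dataclass


@dataclass
class FinetuningArguments:        # hypothetical stand-in for the repo's config
    finetuning_type: str = "lora"


def check_stage(stage: str, finetuning_args: FinetuningArguments) -> None:
    # Mirrors the added guard: only LoRA checkpoints may feed RM/PPO stages.
    if stage != "sft" and finetuning_args.finetuning_type != "lora":
        raise ValueError("RM and PPO training can only be performed with LoRA method.")


check_stage("sft", FinetuningArguments("freeze"))  # OK: SFT accepts any type
check_stage("ppo", FinetuningArguments("lora"))    # OK: PPO with LoRA
check_stage("rm", FinetuningArguments("freeze"))   # raises ValueError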
10 changes: 10 additions & 0 deletions src/utils/other.py
@@ -8,8 +8,10 @@
from transformers import Seq2SeqTrainingArguments
from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
+ from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor


from peft.utils.other import WEIGHTS_NAME


@@ -48,6 +50,8 @@ def update(self, val, n=1):
self.avg = self.sum / self.count


+ # Avoid runtime error in model.generate(do_sample=True).
+ # Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54
class InvalidScoreLogitsProcessor(LogitsProcessor):

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -57,6 +61,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
return scores


+ def get_logits_processor():
+     logits_processor = LogitsProcessorList()
+     logits_processor.append(InvalidScoreLogitsProcessor())
+     return logits_processor


# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
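The hunk above collapses the body of InvalidScoreLogitsProcessor. Based on the ChatGLM-6B source linked in the comment, the processor resets the whole score tensor and places the mass on a fixed token whenever any score is NaN or inf, and get_logits_processor simply wraps one instance in a LogitsProcessorList, the object generate() expects. A self-contained sketch reconstructed under that assumption, not copied from this diff:

import torch
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList


class InvalidScoreLogitsProcessor(LogitsProcessor):
    # Reconstructed from the referenced ChatGLM-6B implementation: if any
    # score is nan/inf, zero everything and put the mass on token id 5.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def get_logits_processor():
    logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    return logits_processor


# Quick check: a LogitsProcessorList is callable and applies each processor.
scores = torch.full((1, 8), float("nan"))
input_ids = torch.zeros((1, 1), dtype=torch.long)
print(get_logits_processor()(input_ids, scores))  # invalid scores replaced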

