Commit

fix cast output layer
hiyouga committed Jun 28, 2023
1 parent 7b67c38 commit a136eb4
Showing 2 changed files with 8 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/utils/common.py
@@ -233,13 +233,16 @@ def get_input_embeddings(self):
         model.get_input_embeddings = MethodType(get_input_embeddings, model)
         model.lm_head = model.transformer.output_layer
         tokenizer.eos_token = "</s>"
-        output_embedding_layer_name = "transformer.output_layer"
+        output_embedding_base_layer = model.transformer
+        output_embedding_layer_name = "output_layer"
     else:
+        output_embedding_base_layer = model
         output_embedding_layer_name = "lm_head"

     model = prepare_model_for_training(
         model,
         finetuning_args.finetuning_type,
+        output_embedding_base_layer,
         output_embedding_layer_name
     ) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
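The common.py change works around the fact that hasattr and getattr treat their string argument as a single literal attribute name, so the old dotted name "transformer.output_layer" was never found on the model and the output-layer cast block was never entered for the ChatGLM-style branch. Splitting the lookup into a base module plus a plain attribute name fixes that. A minimal sketch of the difference, using toy stand-in classes rather than the real ChatGLM modules:

import torch

# Hypothetical stand-ins for a ChatGLM-style module layout.
class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.output_layer = torch.nn.Linear(8, 8)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = ToyTransformer()

model = ToyModel()

# A dotted string is looked up as one literal attribute name, so it is never found.
print(hasattr(model, "transformer.output_layer"))   # False
# Passing the base module plus a plain name resolves the layer as intended.
print(hasattr(model.transformer, "output_layer"))   # True
layer = getattr(model.transformer, "output_layer")  # the torch.nn.Linear output layer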
7 changes: 4 additions & 3 deletions src/utils/other.py
@@ -73,6 +73,7 @@ def get_logits_processor() -> LogitsProcessorList:
 def prepare_model_for_training(
     model: PreTrainedModel,
     finetuning_type: str,
+    output_embedding_base_layer: torch.nn.Module,
     output_embedding_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = ["layernorm"] # for chatglm setting
@@ -87,16 +88,16 @@ def prepare_model_for_training(
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

-    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
-        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
+    if finetuning_type != "full" and hasattr(output_embedding_base_layer, output_embedding_layer_name):
+        output_embedding_layer = getattr(output_embedding_base_layer, output_embedding_layer_name)
         input_dtype = output_embedding_layer.weight.dtype

         class CastOutputToFloat(torch.nn.Sequential):

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)

-        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
+        setattr(output_embedding_base_layer, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

     return model

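With the base layer threaded through, prepare_model_for_training can again wrap the resolved output layer so that hidden states are cast to the layer's own dtype on the way in and the logits come back as float32, keeping the loss computation numerically stable during parameter-efficient training. A standalone sketch of that cast pattern, assuming a bfloat16 output layer and toy shapes rather than the repository's actual model:

import torch

lm_head = torch.nn.Linear(16, 32).to(torch.bfloat16)  # assumed reduced-precision output layer
input_dtype = lm_head.weight.dtype

class CastOutputToFloat(torch.nn.Sequential):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast the input to the layer's dtype, run the layer, then upcast the logits.
        return super().forward(x.to(input_dtype)).to(torch.float32)

base = torch.nn.Module()          # stands in for output_embedding_base_layer
base.lm_head = lm_head
setattr(base, "lm_head", CastOutputToFloat(base.lm_head))

hidden = torch.randn(2, 16)       # float32 hidden states
logits = base.lm_head(hidden)
print(logits.dtype)               # torch.float32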
