enable loraplus setting for dpo trainer (axolotl-ai-cloud#1646)
thepowerfuldeez authored May 22, 2024
1 parent 6299eb5 commit a27d5e1
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/axolotl/core/trainer_builder.py
File mode changed: 100644 → 100755
@@ -798,6 +798,40 @@ class AxolotlDPOTrainer(DPOTrainer):
 
     tag_names = ["axolotl", "dpo"]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.optimizer = None
+
+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            if loraplus_lr_ratio:
+                print("Using lora+")
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
     @wraps(DPOTrainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
@@ -1483,6 +1517,8 @@ def build_training_arguments(self, total_num_steps):
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
 
+        training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
         )
@@ -1535,7 +1571,7 @@ def build_training_arguments(self, total_num_steps):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_cls = TrainingArguments
+        training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
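The final hunk switches the default arguments class from transformers' TrainingArguments to AxolotlTrainingArguments, presumably so the two loraplus_* kwargs added above land on declared fields instead of being rejected as unknown keywords by the base dataclass. For illustration, a hedged sketch of how such fields are typically declared on a TrainingArguments subclass (LoraPlusArgs is a made-up name, not a copy of axolotl's actual class):

# Sketch only: extra dataclass fields so loraplus kwargs survive construction
# of the training arguments. Field names match the diff; the class is hypothetical.
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class LoraPlusArgs(TrainingArguments):  # hypothetical stand-in for AxolotlTrainingArguments
    loraplus_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "learning-rate ratio applied to lora_B parameters (LoRA+)"},
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=None,
        metadata={"help": "separate learning rate for LoRA embedding parameters"},
    )

With fields like these in place, setting loraplus_lr_ratio / loraplus_lr_embedding in the axolotl config flows through build_training_arguments into the trainer's args, where the create_optimizer override in the first hunk picks them up.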
