Skip to content

Commit

Permalink
Fix `eval_strategy`/`evaluation_strategy` compatibility with newer transformers versions (modelscope#1143)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Jun 16, 2024
1 parent db46c9e commit 5710ed3
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 21 deletions.
11 changes: 6 additions & 5 deletions swift/llm/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.utils import is_torch_npu_available

from swift.trainers.dpo_trainers import DPOTrainer
from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
is_dist, is_master, plot_images, seed_everything, show_layers)
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
from .tuner import prepare_model
from .utils import (DPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
set_generation_config)
Expand Down Expand Up @@ -156,6 +156,7 @@ def llm_dpo(args: DPOArguments) -> str:

if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down Expand Up @@ -227,9 +228,9 @@ def llm_dpo(args: DPOArguments) -> str:
'model_info': model_info,
'dataset_info': trainer.dataset_info,
}
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(run_info) + '\n')
if is_local_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
return run_info


Expand Down
11 changes: 6 additions & 5 deletions swift/llm/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.utils import is_torch_npu_available

from swift.trainers.orpo_trainers import ORPOTrainer
from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
is_dist, is_master, plot_images, seed_everything, show_layers)
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
from .tuner import prepare_model
from .utils import (ORPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
set_generation_config)
Expand Down Expand Up @@ -147,6 +147,7 @@ def llm_orpo(args: ORPOArguments) -> str:

if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down Expand Up @@ -230,9 +231,9 @@ def llm_orpo(args: ORPOArguments) -> str:
'model_info': model_info,
'dataset_info': trainer.dataset_info,
}
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(run_info) + '\n')
if is_local_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
return run_info


Expand Down
1 change: 1 addition & 0 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
val_dataset = LazyLLMDataset(val_dataset, template)
if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False

padding_to = args.max_length if args.sft_type == 'longlora' else None
Expand Down
11 changes: 6 additions & 5 deletions swift/llm/simpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from transformers.utils import is_torch_npu_available

from swift.trainers.simpo_trainers import SimPOTrainer
from swift.utils import (check_json_format, get_dist_setting, get_logger, get_main, get_model_info, is_ddp_plus_mp,
is_dist, is_master, plot_images, seed_everything, show_layers)
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
is_ddp_plus_mp, is_dist, is_local_master, is_master, plot_images, seed_everything, show_layers)
from .tuner import prepare_model
from .utils import (SimPOArguments, Template, get_dataset, get_model_tokenizer, get_template, get_time_info,
set_generation_config)
Expand Down Expand Up @@ -145,6 +145,7 @@ def llm_simpo(args: SimPOArguments) -> str:

if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
training_args.do_eval = False
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down Expand Up @@ -215,9 +216,9 @@ def llm_simpo(args: SimPOArguments) -> str:
'model_info': model_info,
'dataset_info': trainer.dataset_info,
}
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(run_info) + '\n')
if is_local_master():
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, run_info)
return run_info


Expand Down
11 changes: 5 additions & 6 deletions swift/trainers/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
TrainerState)
from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics

from swift.utils import is_pai_training_job, use_torchacc
from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
from .arguments import TrainingArguments


Expand Down Expand Up @@ -55,8 +55,7 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=Non
logs[k] = round(logs[k], 8)
if not is_pai_training_job() and state.is_local_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(logs) + '\n')
append_to_jsonl(jsonl_path, logs)
super().on_log(args, state, control, logs, **kwargs)
if state.is_local_process_zero and self.training_bar is not None:
self.training_bar.refresh()
Expand All @@ -67,8 +66,9 @@ class DefaultFlowCallbackNew(DefaultFlowCallback):
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
control = super().on_step_end(args, state, control, **kwargs)
# save the last ckpt
evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
if state.global_step == state.max_steps:
if args.evaluation_strategy != IntervalStrategy.NO:
if evaluation_strategy != IntervalStrategy.NO:
control.should_evaluate = True
if args.save_strategy != IntervalStrategy.NO:
control.should_save = True
Expand All @@ -84,8 +84,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
logs[k] = round(logs[k], 8)
if not is_pai_training_job() and state.is_local_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
with open(jsonl_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(logs) + '\n')
append_to_jsonl(jsonl_path, logs)

_ = logs.pop('total_flos', None)
if state.is_local_process_zero:
Expand Down

0 comments on commit 5710ed3

Please sign in to comment.