Refactor push_to_hub (modelscope#1883)
tastelikefeet committed Sep 2, 2024
1 parent a748cca commit 543194c
Showing 7 changed files with 191 additions and 177 deletions.
1 change: 0 additions & 1 deletion swift/llm/sft.py
@@ -424,7 +424,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
    logger.info(f'images_dir: {images_dir}')
    plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9)
    if args.push_to_hub:
        trainer._add_patterns_to_gitignore(['images/'])
        trainer.push_to_hub()
    run_info = {
        'memory': trainer.perf['memory'],
10 changes: 8 additions & 2 deletions swift/llm/utils/argument.py
@@ -358,6 +358,11 @@ def handle_compatibility(self: Union['SftArguments', 'InferArguments']) -> None:
        if self.lora_target_regex:
            self.target_regex = self.lora_target_regex

        if getattr(self, 'push_hub_strategy', None):
            self.hub_strategy = self.push_hub_strategy
            if self.hub_strategy in ('push_last', 'push_best'):
                self.hub_strategy = 'every_save'

    def handle_custom_dataset_info(self: Union['SftArguments', 'InferArguments']):
        if self.custom_dataset_info is None:
            return
@@ -804,7 +809,7 @@ class SftArguments(ArgumentsBase):
    hub_token: Optional[str] = field(
        default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'})
    hub_private_repo: bool = False
    push_hub_strategy: Literal['end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'] = 'push_best'
    hub_strategy: Literal['end', 'every_save', 'checkpoint', 'all_checkpoints'] = 'every_save'

    # other
    test_oom_error: bool = field(
@@ -888,6 +893,7 @@ class SftArguments(ArgumentsBase):
    custom_train_dataset_path: List[str] = field(default_factory=list)
    custom_val_dataset_path: List[str] = field(default_factory=list)
    device_map_config_path: Optional[str] = None
    push_hub_strategy: Literal['end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'] = 'push_best'

    def _prepare_target_modules(self, target_modules) -> Union[List[str], str]:
        if isinstance(target_modules, str):
@@ -1213,7 +1219,7 @@ def _init_training_args(self) -> None:
            adam_epsilon=self.adam_epsilon,
            hub_model_id=self.hub_model_id,
            hub_private_repo=self.hub_private_repo,
            push_hub_strategy=self.push_hub_strategy,
            hub_strategy=self.push_hub_strategy,
            hub_token=self.hub_token,
            push_to_hub=self.push_to_hub,
            resume_from_checkpoint=self.resume_from_checkpoint,
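The compatibility shim added to handle_compatibility above can be read in isolation: a legacy --push_hub_strategy value is copied onto the new hub_strategy field, and the retired 'push_last'/'push_best' choices collapse to 'every_save', matching the strategy names transformers itself accepts. Below is a minimal illustrative sketch of that mapping; the helper name map_push_hub_strategy is hypothetical and not part of the commit.

from typing import Optional


def map_push_hub_strategy(push_hub_strategy: Optional[str]) -> Optional[str]:
    # Hypothetical standalone mirror of the shim in ArgumentsBase.handle_compatibility:
    # copy the deprecated value over, then fold the two retired choices into 'every_save'.
    if not push_hub_strategy:
        return None  # nothing to migrate
    if push_hub_strategy in ('push_last', 'push_best'):
        return 'every_save'
    # 'end', 'checkpoint' and 'all_checkpoints' keep their spelling
    return push_hub_strategy


assert map_push_hub_strategy('push_best') == 'every_save'
assert map_push_hub_strategy('all_checkpoints') == 'all_checkpoints'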
2 changes: 0 additions & 2 deletions swift/trainers/arguments.py
@@ -18,8 +18,6 @@ class SwiftArgumentsMixin:
    # ckpt only save model
    save_only_model: bool = False
    train_sampler_random: bool = True
    push_hub_strategy: str = field(
        default='push_best', metadata={'choices': {'end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'}})
    acc_strategy: str = field(default='token', metadata={'choices': ['token', 'sentence']})
    loss_name: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
    additional_saved_files: Optional[List[str]] = None
172 changes: 2 additions & 170 deletions swift/trainers/mixin.py
@@ -22,191 +22,23 @@
from transformers.data.data_collator import DataCollator
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model
from transformers.trainer import (ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME,
                                  PREFIX_CHECKPOINT_DIR, SAFE_WEIGHTS_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME,
                                  WEIGHTS_NAME, IntervalStrategy, Trainer, TrainerCallback, is_peft_available)
from transformers.trainer import PREFIX_CHECKPOINT_DIR, TRAINER_STATE_NAME, Trainer, TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments
from transformers.utils import is_sagemaker_mp_enabled, is_torch_npu_available

from swift.hub import Repository
from swift.hub.check_model import check_local_model_is_latest
from swift.torchacc_utils import (save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint, ta_load_optimizer_and_scheduler,
                                  ta_save_optimizer_and_scheduler, ta_trim_graph)
from swift.tuners import SwiftModel
from swift.utils import check_json_format, create_ms_repo, get_logger, use_torchacc
from swift.utils import check_json_format, get_logger, use_torchacc
from swift.utils.constants import Invoke
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

logger = get_logger()


def _push_to_hub(self: Repository, commit_message: str = 'Commit files to Modelscope Hub', **kwargs):
    blocking = kwargs.get('blocking', True)
    self.push(commit_message)
    if not blocking:
        # Compatible with transformers
        return None, None
    else:
        return None


class PushToMsHubMixin:
    repo: Repository

    def _add_patterns_to_file(self, file_name: str, patterns: List[str], commit_message: Optional[str] = None) -> None:
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return
        if isinstance(patterns, str):
            patterns = [patterns]
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'

        # Get current file content
        repo_dir = self.repo.model_dir
        file_path = os.path.join(repo_dir, file_name)
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                current_content = f.read()
        else:
            current_content = ''
        # Add the patterns to file
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if len(content) > 0 and not content.endswith('\n'):
                    content += '\n'
                content += f'{pattern}\n'

        # Write the file if it has changed
        if content != current_content:
            with open(file_path, 'w', encoding='utf-8') as f:
                logger.debug(f'Writing {file_name} file. Content: {content}')
                f.write(content)
            self.repo.push(commit_message)

    def _add_patterns_to_gitignore(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
        self._add_patterns_to_file('.gitignore', patterns, commit_message)

    def _add_patterns_to_gitattributes(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
        new_patterns = []
        suffix = 'filter=lfs diff=lfs merge=lfs -text'
        for pattern in patterns:
            if suffix not in pattern:
                pattern = f'{pattern} {suffix}'
            new_patterns.append(pattern)
        file_name = '.gitattributes'
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
        self._add_patterns_to_file(file_name, new_patterns, commit_message)

    def init_hf_repo(self) -> None:
        """init ms repo. Compatible with transformers>=4.34"""
        self.init_git_repo(at_init=True)

    def init_git_repo(self, at_init: bool = False) -> None:
        if not self.is_world_process_zero():
            return
        if (os.path.exists(self.args.output_dir) and os.listdir(self.args.output_dir) and self.args.overwrite_output_dir
                and at_init):
            # directory not empty.
            shutil.rmtree(self.args.output_dir)
        self.args.hub_model_id = create_ms_repo(self.args.hub_model_id, self.args.hub_token, self.args.hub_private_repo)
        self.repo = Repository(self.args.output_dir, self.args.hub_model_id)
        self._add_patterns_to_gitattributes(['*.safetensors', '*.bin', '*.pt'])
        self.repo.push_to_hub = MethodType(_push_to_hub, self.repo)
        self.repo.local_dir = self.repo.model_dir  # hf compatibility

        # By default, ignore the checkpoint folders
        if self.args.push_hub_strategy != 'all_checkpoints':
            self._add_patterns_to_gitignore(['checkpoint-*/', 'tmp-checkpoint-*/'])

        # Add 'runs/' to .gitignore, ignore tensorboard files
        self._add_patterns_to_gitignore(['runs/'])

        # Add '*.sagemaker' to .gitignore if using SageMaker
        if os.environ.get('SM_TRAINING_ENV'):
            self._add_patterns_to_gitignore(['*.sagemaker-uploading', '*.sagemaker-uploaded'],
                                            'Add `*.sagemaker` patterns to .gitignore')

        self.push_in_progress = None

    def push_to_hub(self, commit_message: str = 'End of training', **kwargs) -> None:
        # user calls manually `push_to_hub` with `self.args.push_to_hub = False`
        create_model_card = kwargs.pop('create_model_card', None)
        if not hasattr(self, 'repo'):
            self.init_git_repo()
        self.save_model(_internal_call=True)

        if not self.is_world_process_zero():
            return

        self.repo.push_to_hub(commit_message, **kwargs)
        # push separately the model card to be independent from the rest of the model
        readme_path = os.path.join(self.args.output_dir, 'README.md')
        if create_model_card is None:
            create_model_card = not os.path.exists(readme_path)
        if create_model_card and self.args.should_save:
            model_name = kwargs.pop('model_name', None)
            if model_name is None and self.args.should_save:
                if self.args.hub_model_id is not None:
                    model_name = self.args.hub_model_id.split('/')[-1]
                else:
                    model_name = os.path.basename(self.args.output_dir)
            self.create_model_card(model_name=model_name, **kwargs)
            self.repo.push_to_hub('update model card README.md', **kwargs)

    def _push_from_checkpoint(self, checkpoint_folder: str) -> None:
        """Compatible with transformers>=4.32"""
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.push_hub_strategy == 'end':
            return
        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        if is_peft_available():
            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.push_hub_strategy == 'checkpoint':
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, 'last-checkpoint')
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f'Training in progress, step {self.state.global_step}'
            else:
                commit_message = f'Training in progress, epoch {int(self.state.epoch)}'
            if self.args.push_hub_strategy == 'push_best':
                folder, checkpoint_name = os.path.split(checkpoint_folder)
                checkpoint_name = checkpoint_name.replace('tmp-checkpoint-', 'checkpoint-')
                last_model_checkpoint = os.path.join(folder, checkpoint_name)
                if last_model_checkpoint == self.state.best_model_checkpoint:
                    self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            else:
                self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
        except Exception as e:
            logger.error(f'Error when pushing to hub: {e}')
        finally:
            if self.args.push_hub_strategy == 'checkpoint':
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)


class SwiftMixin:

    def __init__(self,
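Most of the deletion above is the per-checkpoint push logic that the legacy push_hub_strategy values controlled. As a reading aid, the decision it implemented can be condensed as follows; the helper name should_push_checkpoint is hypothetical and the behavior is paraphrased from the removed _push_from_checkpoint, not from the commit's replacement code.

def should_push_checkpoint(push_hub_strategy: str, checkpoint_dir: str, best_checkpoint: str) -> bool:
    # Condensed restatement of the removed dispatch: decide whether the checkpoint
    # that was just saved should trigger an asynchronous push to the hub.
    if push_hub_strategy == 'end':
        return False  # push only once, after training finishes
    if push_hub_strategy == 'push_best':
        # push only when the checkpoint just written is the current best one
        return checkpoint_dir == best_checkpoint
    # 'push_last', 'checkpoint' and 'all_checkpoints' push on every save;
    # 'checkpoint' additionally mirrors the folder to 'last-checkpoint' first,
    # and only 'all_checkpoints' keeps checkpoint-*/ folders out of .gitignore.
    return True

After this commit the same intent is expressed through the standard hub_strategy argument, with the retired 'push_best' and 'push_last' values approximated by 'every_save' (see the argument.py hunk above).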
