
Commit

DeepSpeed/FSDP ckpt saving utils fixes and FSDP training args fixes (#24591)

* update ds and fsdp ckpt logic

* refactoring

* fix 🐛

* resolve comment

* fix issue with overriding of the fsdp config set by accelerate
pacman100 authored Jul 6, 2023
1 parent 3927404 commit 66a3784
Showing 1 changed file with 18 additions and 34 deletions.
52 changes: 18 additions & 34 deletions src/transformers/trainer.py
@@ -56,7 +56,7 @@
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, deepspeed_load_checkpoint
 from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .modelcard import TrainingSummary
@@ -2737,44 +2737,26 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             or self.fsdp is not None
             or self.is_fsdp_enabled
         ):
-            state_dict = self.model.state_dict()
-            if self.args.should_save:
-                self._save(output_dir, state_dict=state_dict)
             if self.is_fsdp_enabled:
+                os.makedirs(output_dir, exist_ok=True)
                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
+            else:
+                state_dict = self.model.state_dict()
+
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
         elif self.is_deepspeed_enabled:
             # this takes care of everything as long as we aren't under zero3
-            if self.args.should_save and not is_deepspeed_zero3_enabled():
-                if version.parse(accelerate_version) <= version.parse("0.20.3"):
-                    raise ValueError("Install Accelerate from main branch")
+            if version.parse(accelerate_version) <= version.parse("0.20.3"):
+                raise ValueError("Install Accelerate from main branch")
+            try:
                 state_dict = self.accelerator.get_state_dict(self.deepspeed)
-
-                self._save(output_dir, state_dict=state_dict)
-
-            if is_deepspeed_zero3_enabled():
-                # It's too complicated to try to override different places where the weights dump gets
-                # saved, so since under zero3 the file is bogus, simply delete it. The user should
-                # either use deepspeed checkpoint to resume or to recover full weights use
-                # zero_to_fp32.py stored in the checkpoint.
-                if self.args.should_save:
-                    file = os.path.join(output_dir, WEIGHTS_NAME)
-                    if os.path.isfile(file):
-                        # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
-                        os.remove(file)
-
-                # now save the real model if stage3_gather_16bit_weights_on_model_save=True
-                # if false it will not be saved.
-                # This must be called on all ranks
-                if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
-                    logger.warning(
-                        "deepspeed.save_16bit_model didn't save the model, since"
-                        " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                        " zero_to_fp32.py to recover weights"
-                    )
-                    self.model_wrapped.save_checkpoint(output_dir)
+                self._save(output_dir, state_dict=state_dict)
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                self.model_wrapped.save_checkpoint(output_dir)
 
         elif self.args.should_save:
             self._save(output_dir)
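
With this change, the FSDP path saves through save_fsdp_model (after creating output_dir) instead of materializing a full state_dict up front, and the DeepSpeed path no longer deletes a bogus ZeRO-3 weights file: it tries to gather a consolidated state dict and falls back to DeepSpeed's own sharded checkpoint when 16-bit weight gathering is disabled. A minimal sketch of that fallback pattern outside the Trainer, assuming accelerator is an Accelerate Accelerator, engine is the DeepSpeed-wrapped model, and save_fn stands in for Trainer._save (the helper name save_deepspeed_model is hypothetical):

import logging

logger = logging.getLogger(__name__)


def save_deepspeed_model(accelerator, engine, output_dir, save_fn):
    # Sketch of the try/except fallback used in the hunk above; not the actual Trainer code.
    try:
        # Under ZeRO-3 this gathers the consolidated 16-bit weights; Accelerate raises
        # ValueError when stage3_gather_16bit_weights_on_model_save is false.
        state_dict = accelerator.get_state_dict(engine)
        save_fn(output_dir, state_dict=state_dict)
    except ValueError:
        logger.warning(
            "stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint"
            " instead, use zero_to_fp32.py to recover weights"
        )
        # Fall back to DeepSpeed's native sharded checkpoint (called on all ranks);
        # the zero_to_fp32.py script stored in the checkpoint can rebuild full weights offline.
        engine.save_checkpoint(output_dir)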
@@ -3854,8 +3836,10 @@ def create_accelerator_and_postprocess(self):
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)
 
         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:
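
This last hunk stops the Trainer from clobbering FSDP options already set through the Accelerate launch config: instead of defaulting limit_all_gathers and use_orig_params to False, the plugin's current values are used as the fallback, matching the "fix issue with overriding of the fsdp config set by accelerate" item in the commit message. A small self-contained sketch of the pattern, using a hypothetical stand-in for Accelerate's FSDP plugin:

from dataclasses import dataclass


@dataclass
class SimpleFSDPPlugin:
    # Hypothetical stand-in for accelerate's FullyShardedDataParallelPlugin.
    limit_all_gathers: bool = False
    use_orig_params: bool = False


def apply_fsdp_config(plugin: SimpleFSDPPlugin, fsdp_config: dict) -> SimpleFSDPPlugin:
    # Only override an attribute when fsdp_config explicitly sets it; otherwise keep
    # whatever the plugin was already configured with (e.g. via `accelerate launch`).
    plugin.limit_all_gathers = fsdp_config.get("limit_all_gathers", plugin.limit_all_gathers)
    plugin.use_orig_params = fsdp_config.get("use_orig_params", plugin.use_orig_params)
    return plugin


# Example: use_orig_params was enabled by the accelerate config and is not mentioned
# in the Trainer's fsdp_config, so it survives instead of being reset to False.
plugin = SimpleFSDPPlugin(use_orig_params=True)
print(apply_fsdp_config(plugin, {"limit_all_gathers": True}))
# SimpleFSDPPlugin(limit_all_gathers=True, use_orig_params=True)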
