Skip to content

Commit

Permalink
Fix total_num_steps (axolotl-ai-cloud#1566)
Browse files: browse the repository at this point in the history
* Fix `total_num_steps`

* Fix total_num_steps

* lint
  • Loading branch information
bofenghuang authored May 15, 2024
1 parent 1e1921b commit 81da7d2
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
/ cfg.sample_packing_eff_est
/ cfg.sequence_len
// cfg.batch_size
- // int(os.environ.get("WORLD_SIZE", 1))
)
- 1
)
Expand Down Expand Up @@ -359,18 +358,14 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
- data_loader_len = len(data_loader) // cfg.batch_size
+ data_loader_len = len(data_loader) // (
+ cfg.world_size * cfg.gradient_accumulation_steps
+ )
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
- total_num_steps = int(
- math.floor(
- data_loader_len
- * cfg.num_epochs
- / int(os.environ.get("WORLD_SIZE", 1))
- )
- )
+ total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))

def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
Expand All @@ -391,12 +386,7 @@ def calc_sample_packing_eff_est(estimates: List[float]):
)
else:
total_num_steps = int(
- math.ceil(
- len(train_dataset)
- * cfg.num_epochs
- / int(os.environ.get("WORLD_SIZE", 1))
- / cfg.batch_size
- )
+ math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps
Expand Down

0 comments on commit 81da7d2

Please sign in to comment.