Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update no_trainer.py scripts to include accelerate gradient accumulation wrapper #18473

Merged
merged 11 commits into from
Aug 8, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,13 @@ def main():
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
Accelerator(
log_with=args.report_to,
logging_dir=args.output_dir,
gradient_accumulation_steps=args.gradient_accumulation_steps,
)
if args.with_tracking
else Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
Rasmusafj marked this conversation as resolved.
Show resolved Hide resolved
)
logger.info(accelerator.state)
# Make one log on every process with the configuration for debugging.
Expand Down Expand Up @@ -385,7 +391,7 @@ def collate_fn(examples):
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
num_training_steps=math.ceil(len(train_dataloader)) * args.num_train_epochs,
Rasmusafj marked this conversation as resolved.
Show resolved Hide resolved
)

# Prepare everything with our `accelerator`.
Expand Down Expand Up @@ -467,17 +473,20 @@ def collate_fn(examples):
if resume_step is not None and step < resume_step:
completed_steps += 1
continue
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:

with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()

accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

if accelerator.sync_gradients:
sgugger marked this conversation as resolved.
Show resolved Hide resolved
progress_bar.update(1)
completed_steps += 1

Expand Down