
Commit 3318bb5
add eliminate_transpose arg
zhiqiu committed Apr 28, 2024
1 parent ea2926c commit 3318bb5
Showing 1 changed file with 11 additions and 0 deletions.
llm/llama/auto_parallel/run_pretrain_auto.py (11 additions, 0 deletions)
@@ -86,6 +86,12 @@ class PreTrainingArguments(TrainingArguments):
             "help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradient accumulation."
         },
     )
+    eliminate_transpose: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable eliminate_transpose pass, which should replace transpose with reshape when sequence parallel is enabled."
+        },
+    )
     job_schedule_profiler_start: int = field(
         default=-1,
         metadata={"help": "The step to start job_schedule_profiler."},
@@ -132,6 +138,11 @@ def __post_init__(self):
             fused_passes.enable = True
             fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")
 
+        if self.eliminate_transpose:
+            fused_passes = self.strategy.fused_passes
+            fused_passes.enable = True
+            fused_passes.fused_passes_list.append("eliminate_transpose")
+
         logger.info(self.strategy)


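For reference, below is a minimal, self-contained sketch of the flag-to-pass wiring this commit adds. The FusedPasses and Strategy classes are hypothetical stand-ins for Paddle's distributed-strategy objects (in run_pretrain_auto.py the strategy comes from TrainingArguments); only the wiring pattern mirrors the diff above.

from dataclasses import dataclass, field


@dataclass
class FusedPasses:
    # Hypothetical stand-in for the fused-passes config on the strategy.
    enable: bool = False
    fused_passes_list: list = field(default_factory=list)


@dataclass
class Strategy:
    # Hypothetical stand-in for the auto-parallel strategy object.
    fused_passes: FusedPasses = field(default_factory=FusedPasses)


@dataclass
class PreTrainingArguments:
    eliminate_transpose: bool = field(
        default=False,
        metadata={"help": "Enable eliminate_transpose pass."},
    )

    def __post_init__(self):
        self.strategy = Strategy()
        if self.eliminate_transpose:
            # Same wiring as the diff: enable fused passes and
            # register the pass by name.
            fused_passes = self.strategy.fused_passes
            fused_passes.enable = True
            fused_passes.fused_passes_list.append("eliminate_transpose")


args = PreTrainingArguments(eliminate_transpose=True)
print(args.strategy.fused_passes.fused_passes_list)  # ['eliminate_transpose']

In practice the flag would be set when launching pretraining, e.g. --eliminate_transpose true, assuming the usual argument-parser handling of bool dataclass fields.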

