
[AutoConfig] Add multi prune #60954

Merged
merged 8 commits into from Feb 2, 2024
Changes from 1 commit
add prune_by_mp_pp_history
Difers committed Jan 31, 2024
commit 97824558f462631ebf60e475d76acb4f3fb1eade
32 changes: 32 additions & 0 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -190,6 +190,38 @@ def prune_by_pp(tuner_cfg, cur_cfg, history_cfgs=[]):
return False


@register_prune_history
def prune_by_mp_pp_history(tuner_cfg, cur_cfg, history_cfgs, pruned_cfgs):
    """
    Prune by mp and pp degree history, the rule is as follows:
    1. If recompute is off and a searched config with the same
       mp_degree * pp_degree product but a larger mp_degree already
       ran out of memory, the current config is pruned and marked OOM.
    """
    mp_degree = cur_cfg.get("mp_degree", None)
    pp_degree = cur_cfg.get("pp_degree", None)
    use_recompute = cur_cfg.get("recompute", None)

    if mp_degree is None or pp_degree is None or use_recompute is None:
        return False, None

    # Consider previously pruned configs alongside the searched history.
    history_cfgs.extend(pruned_cfgs)
    cfgs = same_cfgs_beside(["mp_degree", "pp_degree"], cur_cfg, history_cfgs)
    if cur_cfg.get("sharding_degree") == 1:
        # With sharding_degree 1 the sharding stage does not affect memory,
        # so configs differing only in sharding_stage are also comparable.
        cfgs = same_cfgs_beside(
            ["mp_degree", "pp_degree", "sharding_stage"], cur_cfg, history_cfgs
        )

    for cfg in cfgs:
        if (
            not use_recompute
            and cfg["mp_degree"] * cfg["pp_degree"] == mp_degree * pp_degree
            and cfg["mp_degree"] > mp_degree
            and cfg.get("max_mem_usage") == "OOM"
        ):
            pruned_reason = f"mp_degree {mp_degree}, pp_degree {pp_degree} may cause OOM because mp_degree {cfg['mp_degree']}, pp_degree {cfg['pp_degree']} already caused OOM."
            log_pruned_info(cur_cfg, pruned_reason)
            cur_cfg["max_mem_usage"] = "OOM"
            return True, cur_cfg

    return False, None


@register_prune
def prune_by_vpp(tuner_cfg, cur_cfg, history_cfgs=[]):
"""
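
As a rough illustration of the rule this commit adds, below is a minimal, self-contained sketch. The helper would_prune and the toy configs are hypothetical and only mirror the core check on plain dicts; the real function additionally relies on same_cfgs_beside and log_pruned_info from this module to select comparable history entries.

    # Hypothetical stand-alone sketch; would_prune is not part of the module.
    def would_prune(cur_cfg, history_cfgs):
        mp, pp = cur_cfg["mp_degree"], cur_cfg["pp_degree"]
        for cfg in history_cfgs:
            if (
                not cur_cfg["recompute"]
                and cfg["mp_degree"] * cfg["pp_degree"] == mp * pp
                and cfg["mp_degree"] > mp
                and cfg.get("max_mem_usage") == "OOM"
            ):
                return True
        return False

    # A config with mp_degree=4, pp_degree=2 already OOMed, so a candidate
    # with the same product but a smaller mp_degree is predicted to OOM too.
    history = [{"mp_degree": 4, "pp_degree": 2, "max_mem_usage": "OOM"}]
    candidate = {"mp_degree": 2, "pp_degree": 4, "recompute": False}
    assert would_prune(candidate, history)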