Skip to content

Commit

Permalink
[14/N] Refactor _new_process_group_helper() to remove repeated code (pytorch#88351)
Browse files Browse the repository at this point in the history

Changes:
- refactor parts of `_new_process_group_helper()` to remove repeated code

Differential Revision: [D41188274](https://our.internmc.facebook.com/intern/diff/D41188274)
Pull Request resolved: pytorch#88351
Approved by: https://github.com/kwen2501
  • Loading branch information
H-Huang authored and pytorchmergebot committed Nov 10, 2022
1 parent 4bcf2c5 commit 1d54ce9
Showing 1 changed file with 25 additions and 67 deletions.
92 changes: 25 additions & 67 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,8 +925,6 @@ def _new_process_group_helper(
pg = ProcessGroupMPI.create(global_ranks_in_group)
if not pg:
return GroupMember.NON_GROUP_MEMBER
_world.pg_map[pg] = (Backend.MPI, None)
_world.pg_names[pg] = group_name
else:
# If this is a subgroup (which means group_ranks is specified),
# we check if the current process is a member of the new group.
Expand All @@ -943,27 +941,6 @@ def _new_process_group_helper(
if pg_options is not None:
raise RuntimeError("GLOO options not supported")
pg = ProcessGroupGloo(prefix_store, group_rank, group_size, timeout=timeout)
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debugability.
if get_debug_level() == DebugLevel.DETAIL:
if not _GLOO_AVAILABLE:
logger.info(
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
GLOO is not available. Build with Gloo to
create a wrapper process group in debug mode
to aid collective desynchronization debugging."""
)
else:
pg = _create_process_group_wrapper(
wrapped_pg=pg,
store_prefix=group_name,
store=store,
rank=group_rank,
world_size=group_size,
timeout=timeout,
)
_world.pg_map[pg] = (Backend.GLOO, store)
_world.pg_names[pg] = group_name
elif backend == Backend.NCCL:
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL " "built in")
Expand All @@ -978,54 +955,12 @@ def _new_process_group_helper(
pg_options._timeout = timeout

pg = ProcessGroupNCCL(prefix_store, group_rank, group_size, pg_options)
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debugability.
if get_debug_level() == DebugLevel.DETAIL:
if not _GLOO_AVAILABLE:
logger.info(
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
GLOO is not available. Build with Gloo to
create a wrapper process group in debug mode
to aid collective desynchronization debugging."""
)
else:
pg = _create_process_group_wrapper(
wrapped_pg=pg,
store_prefix=group_name,
store=store,
rank=group_rank,
world_size=group_size,
timeout=timeout,
)
_world.pg_map[pg] = (Backend.NCCL, store)
_world.pg_names[pg] = group_name
elif backend == Backend.UCC and is_ucc_available():
# TODO: once UCC plugin is fully deprecated, remove
# is_ucc_available() from above elif-condition and raise
# RuntimeError if is_ucc_available() returns false.

pg = ProcessGroupUCC(prefix_store, group_rank, group_size, timeout=timeout)
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debugability.
if get_debug_level() == DebugLevel.DETAIL:
if not _GLOO_AVAILABLE:
logger.info(
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
GLOO is not available. Build with Gloo to
create a wrapper process group in debug mode
to aid collective desynchronization debugging."""
)
else:
pg = _create_process_group_wrapper(
wrapped_pg=pg,
store_prefix=group_name,
store=store,
rank=group_rank,
world_size=group_size,
timeout=timeout,
)
_world.pg_map[pg] = (Backend.UCC, store)
_world.pg_names[pg] = group_name
else:
assert backend.upper() in Backend._plugins, (
f"Unknown c10d backend type {backend.upper()}"
Expand All @@ -1047,9 +982,32 @@ def _new_process_group_helper(
dist_backend_opts.global_ranks_in_group = global_ranks_in_group

pg = creator_fn(dist_backend_opts, pg_options)
_world.pg_map[pg] = (backend, store)
_world.pg_names[pg] = group_name

# Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
if backend in [Backend.GLOO, Backend.NCCL, Backend.UCC]:
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debuggability.
if get_debug_level() == DebugLevel.DETAIL:
if not _GLOO_AVAILABLE:
logger.info(
"""TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
GLOO is not available. Build with Gloo to
create a wrapper process group in debug mode
to aid collective desynchronization debugging."""
)
else:
pg = _create_process_group_wrapper(
wrapped_pg=pg,
store_prefix=group_name,
store=store,
rank=group_rank,
world_size=group_size,
timeout=timeout,
)

# update global state
_world.pg_map[pg] = (backend, store)
_world.pg_names[pg] = group_name
return pg


Expand Down

0 comments on commit 1d54ce9

Please sign in to comment.