[dtensor] enable op db tests by using multithreaded test case (pytorch#92198)

The timing difference between MultiThreadedTestCase and MultiProcessTestCase on the op db tests is striking!

Using MultiThreadedTestCase on an AWS dev node:
```
time pytest test/distributed/_tensor/test_dtensor_ops.py

============= 175 passed, 42 skipped, 397 xfailed in 80.30s (0:01:20) =======

real    1m22.330s
user    1m38.782s
sys     0m18.762s
```
MultiProcessTestCase takes anywhere from 40 minutes to more than an hour, even when using pytest parallel testing tools.
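
For context, here is a minimal sketch (not part of this PR) of what a test looks like when it derives from the MultiThreadedTestCase-backed DTensorOpTestBase that this diff switches to: all "ranks" run as threads inside a single process, so there is no per-rank process spawn or NCCL/gloo setup per test. The class name and test body below are hypothetical illustrations; "cpu" is used only to keep the sketch device-agnostic, and helper signatures may differ between PyTorch versions.
```
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorOpTestBase,
)


class MyThreadedDTensorTest(DTensorOpTestBase):  # hypothetical example class
    @property
    def world_size(self) -> int:
        # Ranks are simulated as threads in one process, so this does not
        # require 4 GPUs or 4 worker processes.
        return 4

    def test_shard_and_add(self):
        # Build a 1-D mesh over the simulated ranks and shard an 8x8 tensor on dim 0.
        mesh = DeviceMesh("cpu", torch.arange(self.world_size))
        global_tensor = torch.randn(8, 8)
        dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])

        out = dtensor + dtensor
        # Each simulated rank holds one 2x8 shard of the 8x8 result.
        self.assertEqual(out.to_local().shape, (2, 8))


if __name__ == "__main__":
    run_tests()
```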

Pull Request resolved: pytorch#92198
Approved by: https://github.com/XilunWu
wanchaol authored and pytorchmergebot committed Jan 17, 2023
1 parent 2ce63ef commit 801d831
Showing 4 changed files with 128 additions and 137 deletions.
256 changes: 124 additions & 132 deletions test/distributed/_tensor/test_dtensor_ops.py
@@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import sys
import unittest
import warnings

@@ -23,10 +22,8 @@
TEST_WITH_ASAN,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DEVICE_TYPE,
DTensorConverter,
DTensorTestBase,
TEST_SKIPS,
DTensorOpTestBase,
)
from torch.testing._internal.distributed._tensor.dtensor_lagging_op_db import (
dtensor_lagging_op_db,
@@ -42,32 +39,6 @@
common_ops.XS = 2


def assert_ref_dtensor_equal(test_case, dtensor_rs, rs):
flat_dtensor_rs, _ = tree_flatten(dtensor_rs)
flat_rs, _ = tree_flatten(rs)
test_case.assertEqual(len(flat_dtensor_rs), len(flat_rs))
for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):

if not isinstance(r, torch.Tensor):
continue

test_case.assertIsInstance(dtensor_r, torch.Tensor)
test_case.assertEqual(
dtensor_r.shape,
r.shape,
f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
)
test_case.assertEqual(
dtensor_r.requires_grad,
r.requires_grad,
"op result requires_grad mismatch!"
f"original requires_grad: {r.requires_grad}, "
f"dtensor requires_grad: {dtensor_r.requires_grad}",
)

test_case.assertEqual(dtensor_r.to_local(), r)


# Copied from functorch
def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
return (op_name, variant_name, device_type, dtypes, True)
@@ -118,7 +89,7 @@ def wrapped(fn):

# Re-generate this failed list, turn on dry_run of the below func
# check_dtensor_func(self, test, op, dry_run=True), then run something
# like python test/spmd/tensor/test_dtensor_ops.py > failed.expect
# like python test/distributed/_tensor/test_dtensor_ops.py > failed.expect
dtensor_fails = {
# these sometimes pass and sometimes fail
# we need to remove many of them from list once op
@@ -195,7 +166,6 @@ def wrapped(fn):
xfail("einsum"),
xfail("empty"),
xfail("empty_like"),
xfail("eq"),
xfail("eye"),
xfail("fft.fft2"),
xfail("fft.fft"),
@@ -212,6 +182,7 @@
xfail("fft.rfft2"),
xfail("fft.rfft"),
xfail("fft.rfftn"),
xfail("fill"),
xfail("flip"),
xfail("fliplr"),
xfail("flipud"),
@@ -220,6 +191,7 @@
xfail("fmin"),
xfail("frexp"),
xfail("full"),
xfail("full_like"),
xfail("gather"),
xfail("geqrf"),
xfail("gradient"),
@@ -234,10 +206,8 @@
xfail("index_put"),
xfail("index_reduce"),
xfail("index_select"),
xfail("isfinite"),
xfail("isin"),
xfail("isinf"),
xfail("isnan"),
xfail("isneginf"),
xfail("isposinf"),
xfail("kthvalue"),
@@ -289,7 +259,6 @@ def wrapped(fn):
xfail("log_softmax", "with_dtype"),
xfail("logcumsumexp"),
xfail("logdet"),
xfail("logical_not"),
xfail("logspace"),
xfail("logsumexp"),
xfail("lt"),
@@ -456,7 +425,6 @@ def wrapped(fn):
xfail("searchsorted"),
xfail("select"),
xfail("select_scatter"),
xfail("signbit"),
xfail("sort"),
xfail("sparse.sampled_addmm"),
xfail("special.airy_ai"),
@@ -492,7 +460,6 @@ def wrapped(fn):
xfail("signal.windows.exponential"),
xfail("signal.windows.gaussian"),
xfail("signal.windows.kaiser"),
xfail("squeeze"),
xfail("stack"),
xfail("std"),
xfail("std_mean"),
@@ -523,6 +490,7 @@ def wrapped(fn):
xfail("vdot"),
xfail("view_as_complex"),
xfail("vstack"),
xfail("where"),
xfail("zeros"),
# ops inside this might even fail without dtensor
# tests, as we rescale op db common test size factor (i.e. L, M, S)
@@ -559,6 +527,9 @@ def wrapped(fn):
skip("prod"),
skip("segment_reduce", "lengths"),
skip("segment_reduce", "offsets"),

# TODO: fix the following ops
skip("squeeze"),
}


@@ -573,98 +544,15 @@ def wrapped(fn):
]


def run_dtensor_crossref(test_case, func, args, kwargs):
to_dtensor = DTensorConverter(test_case.mesh, args, kwargs)

# TODO: also handle cases where func raise an exception
rs = func(*args, **kwargs)

def to_replicate(e: object) -> object:
return (
e.redistribute(test_case.mesh, test_case.mesh.ndim * [Replicate()])
if isinstance(e, DTensor)
else e
)

try:
# Suppress warnings, this doesn't matter for test_meta.py
# but it does matter if you want to use this decorator
# for cross-ref testing, as some tests may be looking at
# errors
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# for every comb of sharding choices, we test if it works
for dtensor_args, dtensor_kwargs in to_dtensor:
# Only attempt if we managed to convert all tensors to DTensor
# (if any of them failed, we're in a mixed tensor situation and
# this is not allowed in DTensor)
if to_dtensor.successful():
# Handle special cases first if there's any
# Suppress warnings, this doesn't matter for test_meta.py
# but it does matter if you want to use this decorator
# for cross-ref testing, as some tests may be looking at
# errors
dtensor_rs = func(*dtensor_args, **dtensor_kwargs)

# we need to skip tests containing tensors of zero elmeents for now.
# see issue: https://github.com/pytorch/tau/issues/470
# TODO remove this once issue above fixed.
flat_args, _ = tree_flatten(dtensor_rs)
if any(
isinstance(e, torch.Tensor) and e.numel() == 0
for e in flat_args
):
continue

# redistribute/all_gather the results to compare with normal output
dtensor_rs = tree_map(to_replicate, dtensor_rs)
try:
if resolve_name(func) not in skip_bw:
if isinstance(dtensor_rs, DTensor):
dtensor_rs.to_local().sum().backward()
elif isinstance(dtensor_rs, tuple):
dtensor_rs[0].to_local().sum().backward()

except Exception as e:
# TODO(anj): Remove this guard exception after gaining more confidence.
if torch.distributed.get_rank() == 0:
print(
f"failed to run BW: {resolve_name(func)}, {func}, {str(e)})"
)
assert_ref_dtensor_equal(test_case, dtensor_rs, rs)
else:
raise RuntimeError(
f"failed to convert args to DTensor; "
f"originally (*{args}, **{kwargs})"
)
except Exception as e:
raise RuntimeError(
f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})"
) from e

return rs


def check_dtensor_func(test_case, test_func, opinfo, dry_run=False):
try:
test_func()
except Exception:
test_case.destroy_pg()
if not dry_run:
raise
if dist.get_rank() == 0:
if opinfo.variant_test_name:
print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
else:
print(f"xfail('{opinfo.name}'),")
else:
test_case.destroy_pg()

OP_DB_WORLD_SIZE = 4
DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() >= OP_DB_WORLD_SIZE else "cpu"


class TestDTensorOps(DTensorTestBase):
class TestDTensorOps(DTensorOpTestBase):
@property
def world_size(self) -> int:
return 4
return OP_DB_WORLD_SIZE

# only allow float dtype for now, we can relax this constraint
# when feel necessary later (i.e when adding quantization support).
@@ -673,11 +561,6 @@ def world_size(self) -> int:
@ops(dtensor_lagging_op_db, allowed_dtypes=(torch.float,))
@skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
def test_dtensor_op_db(self, dtype, op):
pg_backend = "nccl" if DEVICE_TYPE == "cuda" else "gloo"
if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

self.init_pg(backend=pg_backend)
self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))

# test each op with dist tensor inputs and normal inputs
@@ -687,14 +570,123 @@ def test():
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs

run_dtensor_crossref(self, op.op, args, kwargs)
self.run_dtensor_crossref(op.op, args, kwargs)
# we need to figure out a way to test the out variant, out variant testing
# is tricky, as we need to pre allocate the dtensor out, some of them rely
# on sharding placements to be pre-known (i.e. mm.out)
# if isinstance(expected, torch.Tensor) and op.supports_out:
# func(*args, **kwargs, out=expected)

check_dtensor_func(self, test, op)
self.check_dtensor_func(test, op)

def assert_ref_dtensor_equal(self, dtensor_rs, rs):
flat_dtensor_rs, _ = tree_flatten(dtensor_rs)
flat_rs, _ = tree_flatten(rs)
self.assertEqual(len(flat_dtensor_rs), len(flat_rs))
for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):

if not isinstance(r, torch.Tensor):
continue

self.assertIsInstance(dtensor_r, torch.Tensor)
self.assertEqualOnRank(
dtensor_r.shape,
r.shape,
f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
)
self.assertEqualOnRank(
dtensor_r.requires_grad,
r.requires_grad,
"op result requires_grad mismatch!"
f"original requires_grad: {r.requires_grad}, "
f"dtensor requires_grad: {dtensor_r.requires_grad}",
)

self.assertEqualOnRank(dtensor_r.to_local(), r)

def run_dtensor_crossref(self, func, args, kwargs):
to_dtensor = DTensorConverter(self.mesh, args, kwargs)

# TODO: also handle cases where func raises an exception
rs = func(*args, **kwargs)

def to_replicate(e: object) -> object:
return (
e.redistribute(self.mesh, self.mesh.ndim * [Replicate()])
if isinstance(e, DTensor)
else e
)

try:
# Suppress warnings, this doesn't matter for test_meta.py
# but it does matter if you want to use this decorator
# for cross-ref testing, as some tests may be looking at
# errors
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# for every comb of sharding choices, we test if it works
for dtensor_args, dtensor_kwargs in to_dtensor:
# Only attempt if we managed to convert all tensors to DTensor
# (if any of them failed, we're in a mixed tensor situation and
# this is not allowed in DTensor)
if to_dtensor.successful():
# Handle special cases first if there's any
# Suppress warnings, this doesn't matter for test_meta.py
# but it does matter if you want to use this decorator
# for cross-ref testing, as some tests may be looking at
# errors
dtensor_rs = func(*dtensor_args, **dtensor_kwargs)

# we need to skip tests containing tensors of zero elements for now.
# see issue: https://github.com/pytorch/tau/issues/470
# TODO remove this once issue above fixed.
flat_args, _ = tree_flatten(dtensor_rs)
if any(
isinstance(e, torch.Tensor) and e.numel() == 0
for e in flat_args
):
continue

# redistribute/all_gather the results to compare with normal output
dtensor_rs = tree_map(to_replicate, dtensor_rs)
try:
if resolve_name(func) not in skip_bw:
if isinstance(dtensor_rs, DTensor):
dtensor_rs.to_local().sum().backward()
elif isinstance(dtensor_rs, tuple):
dtensor_rs[0].to_local().sum().backward()

except Exception as e:
# TODO(anj): Remove this guard exception after gaining more confidence.
if torch.distributed.get_rank() == 0:
print(
f"failed to run BW: {resolve_name(func)}, {func}, {str(e)})"
)
self.assert_ref_dtensor_equal(dtensor_rs, rs)
else:
raise RuntimeError(
f"failed to convert args to DTensor; "
f"originally (*{args}, **{kwargs})"
)
except Exception as e:
raise RuntimeError(
f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})"
) from e

return rs


def check_dtensor_func(self, test_func, opinfo, dry_run=False):
try:
test_func()
except Exception:
if not dry_run:
raise
if dist.get_rank() == 0:
if opinfo.variant_test_name:
print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
else:
print(f"xfail('{opinfo.name}'),")


# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
1 change: 0 additions & 1 deletion test/run_test.py
@@ -138,7 +138,6 @@ def skip_test_p(name: str) -> bool:
"distributed/launcher/bin/test_script_is_torchelastic_launched",
"distributed/launcher/bin/test_script_local_rank",
"distributed/test_c10d_spawn",
"distributed/_tensor/test_dtensor_ops",
'distributions/test_transforms',
'distributions/test_utils',
],