Skip to content

Commit

Permalink
Add a basic test for "nvprims_nvfuser" Dynamo backend (pytorch#88186)
Browse files Browse the repository at this point in the history
Ref. pytorch#87797 (comment)

Pull Request resolved: pytorch#88186
Approved by: https://github.com/ezyang
  • Loading branch information
IvanYashchuk authored and pytorchmergebot committed Nov 2, 2022
1 parent 9ebb8d5 commit 6b5d7fc
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
50 changes: 50 additions & 0 deletions test/test_nvfuser_dynamo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Owner(s): ["module: nvfuser"]

import unittest
import warnings

import torch
import torch._dynamo as torchdynamo
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
run_tests,
skipIfTorchDynamo,
TEST_WITH_ROCM,
TestCase,
IS_WINDOWS,
)
from torch.testing._internal.jit_utils import RUN_CUDA

# nvFuser needs a CUDA device and is not exercised on ROCm builds.
RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM


def is_pre_volta():
    """Report whether the active CUDA device predates Volta (SM major < 7).

    Returns False outright when nvFuser cannot run at all, so the
    pre-Volta skip never fires on configurations that are already
    skipped for other reasons.
    """
    if not RUN_NVFUSER:
        return False
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return device_props.major < 7


@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
class TestNvFuserDynamo(TestCase):
    """Smoke test for the "nvprims_nvfuser" TorchDynamo backend."""

    def test_basic(self):
        lhs = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
        rhs = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)

        @torchdynamo.optimize("nvprims_nvfuser")
        def func(a, b):
            return a.sin() + b.cos()

        # The compiled call must complete without raising and without
        # emitting any warnings.
        with warnings.catch_warnings(record=True) as caught:
            nvfuser_result = func(lhs, rhs)
        self.assertEqual(len(caught), 0)

        # The undecorated function provides the eager-mode reference.
        eager_result = func.__wrapped__(lhs, rhs)
        self.assertEqual(eager_result, nvfuser_result)


if __name__ == "__main__":
    # Allow running this test file directly (python test/test_nvfuser_dynamo.py).
    run_tests()
4 changes: 3 additions & 1 deletion torch/_dynamo/optimizations/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs):


def prims_executor(gm, inputs, *, executor):
from functorch.compile import make_boxed_func

# This function is called once per forward/backward pass of a graph in AOT
# Autograd. We use it to set up the nvFuser-specific FX graph and return
# execute function.
Expand All @@ -274,7 +276,7 @@ def prims_executor(gm, inputs, *, executor):
prim_gm = make_fx(gm)(*inputs)

# Then we return a callable that executes the "prim_gm" graph
return partial(execute, prim_gm, executor=executor)
return make_boxed_func(partial(execute, prim_gm, executor=executor))


def create_nvprims_backend(*, executor):
Expand Down

0 comments on commit 6b5d7fc

Please sign in to comment.