Refactor get numerical jacobian to calculate wrt all outputs at once (pytorch#54378)

Summary:
Pull Request resolved: pytorch#54378

### For release notes
`torch.autograd.gradcheck.get_numerical_jacobian` (not part of the public API) is being deprecated.

In the future, user code relying on this function will break because, among other changes, `get_numerical_jacobian` now returns `List[Tuple[torch.Tensor]]` instead of `List[torch.Tensor]`.

More details:
For a `fn` that takes M inputs and returns N outputs, `get_numerical_jacobian` now returns a list of M N-tuples of jacobians, where `output[i][j]` is the numerical jacobian w.r.t. the i-th input and the j-th output. Previously, it returned a list of tensors, each the jacobian w.r.t. one of the M inputs for a specific output. Finally, the function passed in as the `fn` parameter should now expect to handle individual parameters, whereas previously `fn` was required to accept its parameters wrapped in a tuple. The sketch below illustrates the new return layout.

--- end ---
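
To make the new layout concrete, here is a minimal, self-contained sketch (not the actual gradcheck implementation; the helper `numerical_jacobians` and the example `fn` are purely illustrative) that builds the same `List[Tuple[torch.Tensor]]` structure via central differences for a function with M = 2 inputs and N = 2 outputs:

```python
import torch

def fn(x, y):
    # Hypothetical example function: 2 inputs, 2 outputs.
    return 2 * x + y, x + 2 * y

def numerical_jacobians(fn, inputs, eps=1e-6):
    # Builds result[i][j]: the numerical jacobian of output j w.r.t. input i,
    # returned as a list of M tuples of N tensors.
    outputs = fn(*inputs)
    result = []
    for inp in inputs:
        jacobians = [torch.zeros(inp.numel(), out.numel(), dtype=inp.dtype)
                     for out in outputs]
        flat = inp.reshape(-1)  # view into inp, used for in-place perturbation
        for k in range(flat.numel()):
            orig = flat[k].item()
            flat[k] = orig + eps
            outs_plus = [o.clone() for o in fn(*inputs)]
            flat[k] = orig - eps
            outs_minus = [o.clone() for o in fn(*inputs)]
            flat[k] = orig  # restore the original value
            for j, (p, m) in enumerate(zip(outs_plus, outs_minus)):
                jacobians[j][k] = (p - m).reshape(-1) / (2 * eps)
        result.append(tuple(jacobians))
    return result

a = torch.rand(2, 2, dtype=torch.float64)
b = torch.rand(2, 2, dtype=torch.float64)
jac = numerical_jacobians(fn, (a, b))
print(torch.allclose(jac[0][0], 2 * torch.eye(4, dtype=torch.float64)))  # d(2x+y)/dx = 2I
print(torch.allclose(jac[1][0], torch.eye(4, dtype=torch.float64)))      # d(2x+y)/dy = I
```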

This PR addresses the comment here: pytorch#53857 (comment). It reduces the run-time of the old gradcheck's `get_numerical_jacobian` by a factor of num_outputs. However, because very few ops actually return multiple outputs, there is not much real speed-up here.
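
To see where the factor of num_outputs comes from, here is a hedged sketch (again, not the actual gradcheck code; `counted_fn` and the example `fn` are made up for illustration) that counts function evaluations for the per-output strategy versus the all-outputs-at-once strategy:

```python
import torch

def fn(x):
    # Hypothetical function with 2 outputs.
    return 2 * x, x ** 2

calls = 0
def counted_fn(x):
    global calls
    calls += 1
    return fn(x)

x = torch.rand(3, dtype=torch.float64)
eps = 1e-6

# Per-output strategy (old): one full central-difference sweep for each output.
calls = 0
for out_idx in range(len(fn(x))):
    for k in range(x.numel()):
        for sign in (+1, -1):
            xp = x.clone()
            xp.view(-1)[k] += sign * eps
            counted_fn(xp)[out_idx]  # only this one output is used
per_output_calls = calls

# All-outputs-at-once strategy (new): a single sweep; every output is recorded.
calls = 0
for k in range(x.numel()):
    for sign in (+1, -1):
        xp = x.clone()
        xp.view(-1)[k] += sign * eps
        counted_fn(xp)  # all outputs are used
all_at_once_calls = calls

print(per_output_calls, all_at_once_calls)  # 12 vs 6: saves a factor of num_outputs
```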

The main benefit of doing this change as part of the refactor is that it helps us isolate bugs specific to switching `get_numerical_jacobian` from running per output to running on all outputs at once. Much of the logic implemented here will be the same for the fast gradcheck case, so knowing for certain that everything passes after this stage will make the next step much simpler.

The `get_numerical_jacobian` API is also used in common_nn, so we update the call site there as well.

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D27728720

Pulled By: soulitzer

fbshipit-source-id: ee0f90b4f26ddc5fdbe949c4965eaa91c9ed0bb8
soulitzer authored and facebook-github-bot committed Apr 13, 2021
1 parent fc6985e commit 381b3d8
Showing 4 changed files with 341 additions and 93 deletions.
115 changes: 115 additions & 0 deletions test/test_autograd.py
@@ -4021,6 +4021,9 @@ def test_gradcheck_check_no_differentiable_outputs(self):
gradcheck(lambda x: torch.tensor([x]), x)
self.assertFalse(gradcheck(lambda x: torch.tensor([x]), x, raise_exception=False))

# succeed when no outputs at all
self.assertTrue(gradcheck(lambda x: (), (x,)))

def test_gradcheck_check_batched_grad(self):
x = torch.rand(10, requires_grad=True).to_sparse()
# runtime error while computing batched grad (prints a big error)
@@ -4127,6 +4130,110 @@ def fn3(x): # C -> R
gradcheck(fn3, (x_c,))
self.assertFalse(gradcheck(fn3, (x_c,), raise_exception=False))

def test_gradcheck_dense_and_sparse_inputs(self):
def fn(x, y):
return x * y.coalesce().to_dense()
a = torch.rand(2, 2, requires_grad=True)
b = torch.rand(2, 2).to_sparse().requires_grad_(True)
self.assertTrue(gradcheck(fn, (a, b), check_sparse_nnz=True, check_batched_grad=False))

@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_gradcheck_multiple_mkldnn_inputs(self):
def fn(x, y):
return x + y.to_dense()
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
self.assertTrue(gradcheck(fn, (a, b), atol=1e-1, check_batched_grad=False))

def fn2(x, y):
return x.to_dense() + y.to_dense()
c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
self.assertTrue(gradcheck(fn, (a, c), atol=1e-1, check_batched_grad=False))

def test_gradcheck_output_shape_or_dtype_depend_on_values(self):
def fn(x):
if torch.all(x >= 1):
return torch.cat([x, x])
else:
return x
a = torch.ones(1, requires_grad=True)
with self.assertRaisesRegex(AssertionError, 'return outputs with the same shape when inputs are perturbed'):
self.assertTrue(gradcheck(fn, (a,)))

def fn2(x):
if torch.all(x >= 1):
return x.to(torch.float32)
else:
return x
with self.assertRaisesRegex(AssertionError, 'return outputs with the same dtype when inputs are perturbed'):
self.assertTrue(gradcheck(fn2, (a,)))

def test_gradcheck_complex_non_complex_outputs(self):
def fn(x, y):
z = torch.complex(x, y)
return z, x + 1
a = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
b = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
self.assertTrue(gradcheck(fn, (a, b)))

def fn2(z):
return z, torch.real(z)
c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
self.assertTrue(gradcheck(fn2, (c)))

def test_gradcheck_get_numerical_jacobian(self):
# get_numerical_jacobian is deprecated and no longer used internally by gradcheck
from torch.autograd.gradcheck import get_numerical_jacobian

def fn(inputs):
# get_numerical_jacobian requires fn to take inputs as a tuple
# and returns the jacobian wrt the first output
x = inputs[0]
y = inputs[1]
return 2 * x + y, x + 2 * y
a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)

with self.assertWarnsRegex(UserWarning, "get_numerical_jacobian was part of PyTorch's private API"):
jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
self.assertEqual(jacobian[0], 2 * torch.eye(4))

with self.assertWarnsRegex(UserWarning, "get_numerical_jacobian was part of PyTorch's private API"):
jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
self.assertEqual(jacobian[0], 2 * torch.eye(4))
self.assertEqual(jacobian[1], 1 * torch.eye(4))

def test_gradcheck_get_analytical_jacobian(self):
from torch.autograd.gradcheck import get_analytical_jacobian

def fn(x, y):
return 2 * x + y, x + 2 * y

a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)

outputs = fn(a, b)
with self.assertWarnsRegex(UserWarning, "get_analytical_jacobian was part of PyTorch's private API"):
jacobians, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian((a, b), outputs[0])
self.assertEqual(jacobians[0], 2 * torch.eye(4))
self.assertEqual(jacobians[1], 1 * torch.eye(4))
self.assertTrue(reentrant)

class NonDetFunc(Function):
@staticmethod
def forward(ctx, x, jitter=0.0):
ctx._jitter = jitter
return x

@staticmethod
def backward(ctx, grad_out):
return NonDetFunc.apply(grad_out, ctx._jitter) * (1 + torch.rand_like(grad_out) * ctx._jitter), None

outputs = NonDetFunc.apply(a, 1e-6)
with self.assertWarnsRegex(UserWarning, "get_analytical_jacobian was part of PyTorch's private API"):
jacobians, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian((a,), outputs)
self.assertFalse(reentrant)

def test_version_counter(self):
x = torch.randn(1, 2)

@@ -7843,6 +7950,14 @@ def fn(vec):
gradcheck(fn, (vec))
gradgradcheck(fn, (vec))

@onlyCUDA
def test_gradcheck_input_output_different_device(self, device):
x = torch.ones((1,), device="cuda", requires_grad=True)
gradcheck(lambda x: x.to("cpu"), (x,))

x = torch.ones((1,), device="cpu", requires_grad=True)
gradcheck(lambda x: x.to("cuda"), (x,))

def test_logcumsumexp_large_value(self, device):
a = torch.rand(4, 4, 4, dtype=torch.double, requires_grad=True)
with torch.no_grad():
13 changes: 7 additions & 6 deletions test/test_overrides.py
@@ -803,32 +803,33 @@ def test_gradcheck(self):
# if the gradcheck implementation changes. It's best to
# aim for attributes that may be commonly present on other
# Tensor-likes.
self.assertEqual(total_used_attrs, {
self.assertEqual({
'data',
'device',
'dtype',
'is_complex',
'is_floating_point',
'is_sparse',
'layout',
'nelement',
'new_zeros',
'numel',
'requires_grad',
'retain_grad',
'size',
'stride',
})
}, total_used_attrs)

self.assertEqual(total_used_calls, {
self.assertEqual({
torch.Tensor.new_zeros,
torch.Tensor.size,
torch.Tensor.is_complex,
torch.Tensor.is_floating_point,
torch.Tensor.nelement,
torch.Tensor.numel,
torch.Tensor.retain_grad,
torch.Tensor.stride,
torch.autograd.grad,
torch.add,
})
}, total_used_calls)

class TestNamedTuple(TestCase):
""" Regression test for gh-47090 """
