forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor get numerical jacobian to calculate wrt all outputs at once (p…
…ytorch#54378) Summary: Pull Request resolved: pytorch#54378 ### For release notes `torch.autograd.gradcheck.get_numerical_jacobian` (not part of the public api) is being deprecated. In the future, user code relying on this function will break because, among other changes, `get_numerical_jacobian` now returns `List[Tuple[torch.Tensor]]` instead of `List[torch.Tensor]`. (more details if necessary) For a `fn` that takes in M inputs and N outputs we now return a list of M N-tuples of jacobians where `output[i][j]` would represent the numerical jacobian w.r.t. to the ith input and the jth output. Previously `get_numerical_jacobian` returned a list of tensors where each tensor represents the jacobian w.r.t. to each of the M inputs and a specific output. Finally, the function passed in as the parameter `fn` should expect to handle individual parameters, where previously `fn` is required to expect its parameters wrapped in a tuple. --- end -- This PR addresses the comment here pytorch#53857 (comment), to reduce the run-time of old gradcheck's get numerical jacobian by a factor of num_outputs. However, because very few ops actually return multiple outputs, there is not too much real speed up here. The main benefit of doing this change as part of the refactor is that it helps us isolate the possible bugs that are specific to switching `get numerical jacobian` to run in a per output way vs all outputs at once. Much of the logic implemented here will be the same for the fast gradcheck case, so knowing for certain that everything should pass after this stage will make the next step much simpler. The get_numerical_jacobian api is also being used in common_nn. So we update the callsite there as well. Test Plan: Imported from OSS Reviewed By: jbschlosser Differential Revision: D27728720 Pulled By: soulitzer fbshipit-source-id: ee0f90b4f26ddc5fdbe949c4965eaa91c9ed0bb8
- Loading branch information
1 parent
fc6985e
commit 381b3d8
Showing
4 changed files
with
341 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.