forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add autograd automatic anomaly detection (pytorch#7677)
* add autograd automatic anomaly detection * python 3 string support * Fix non python build * fix typo in doc * better test and naming fix * fix no python build and python object handling * fix missing checks * clean NO_PYTHON build * Remove unwanted changes
- Loading branch information
Showing
20 changed files
with
344 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import torch | ||
|
||
|
||
class detect_anomaly(object): | ||
r"""Context-manager that enable anomaly detection for the autograd engine. | ||
This does two things: | ||
- Running the forward pass with detection enabled will allow the backward | ||
pass to print the traceback of the forward operation that created the failing | ||
backward function. | ||
- Any backward computation that generate "nan" value will raise an error. | ||
Example: | ||
>>> import torch | ||
>>> from torch import autograd | ||
>>> class MyFunc(autograd.Function): | ||
... @staticmethod | ||
... def forward(ctx, inp): | ||
... return inp.clone() | ||
... @staticmethod | ||
... def backward(ctx, gO): | ||
... # Error during the backward pass | ||
... raise RuntimeError("Some error in backward") | ||
... return gO.clone() | ||
>>> def run_fn(a): | ||
... out = MyFunc.apply(a) | ||
... return out.sum() | ||
>>> inp = torch.rand(10, 10, requires_grad=True) | ||
>>> out = run_fn(inp) | ||
>>> out.backward() | ||
Traceback (most recent call last): | ||
File "<stdin>", line 1, in <module> | ||
File "/your/pytorch/install/torch/tensor.py", line 93, in backward | ||
torch.autograd.backward(self, gradient, retain_graph, create_graph) | ||
File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward | ||
allow_unreachable=True) # allow_unreachable flag | ||
File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply | ||
return self._forward_cls.backward(self, *args) | ||
File "<stdin>", line 8, in backward | ||
RuntimeError: Some error in backward | ||
>>> with autograd.detect_anomaly(): | ||
... inp = torch.rand(10, 10, requires_grad=True) | ||
... out = run_fn(inp) | ||
... out.backward() | ||
Traceback of forward call that caused the error: | ||
File "tmp.py", line 53, in <module> | ||
out = run_fn(inp) | ||
File "tmp.py", line 44, in run_fn | ||
out = MyFunc.apply(a) | ||
Traceback (most recent call last): | ||
File "<stdin>", line 4, in <module> | ||
File "/your/pytorch/install/torch/tensor.py", line 93, in backward | ||
torch.autograd.backward(self, gradient, retain_graph, create_graph) | ||
File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward | ||
allow_unreachable=True) # allow_unreachable flag | ||
File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply | ||
return self._forward_cls.backward(self, *args) | ||
File "<stdin>", line 8, in backward | ||
RuntimeError: Some error in backward | ||
""" | ||
|
||
def __init__(self): | ||
self.prev = torch.is_anomaly_enabled() | ||
|
||
def __enter__(self): | ||
torch.set_anomaly_enabled(True) | ||
|
||
def __exit__(self, *args): | ||
torch.set_anomaly_enabled(self.prev) | ||
return False | ||
|
||
|
||
class set_detect_anomaly(object): | ||
r"""Context-manager that sets the anomaly detection for the autograd engine on or off. | ||
``set_detect_anomaly`` will enable or disable the autograd anomaly detection | ||
based on its argument :attr:`mode`. | ||
It can be used as a context-manager or as a function. | ||
See ``detect_anomaly`` above for details of the anomaly detection behaviour. | ||
Arguments: | ||
mode (bool): Flag whether to enable anomaly detection (``True``), | ||
or disable (``False``). | ||
""" | ||
|
||
def __init__(self, mode): | ||
self.prev = torch.is_anomaly_enabled() | ||
torch.set_anomaly_enabled(mode) | ||
|
||
def __enter__(self): | ||
pass | ||
|
||
def __exit__(self, *args): | ||
torch.set_anomaly_enabled(self.prev) | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#include "torch/csrc/autograd/anomaly_mode.h" | ||
|
||
namespace torch { namespace autograd { | ||
|
||
bool AnomalyMode::_enabled = 0; | ||
|
||
}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#pragma once | ||
|
||
namespace torch { namespace autograd { | ||
|
||
struct AnomalyMode { | ||
static bool is_enabled() { | ||
return _enabled; | ||
} | ||
static void set_enabled(bool enabled) { | ||
_enabled = enabled; | ||
} | ||
|
||
private: | ||
static bool _enabled; | ||
}; | ||
|
||
|
||
struct AnomalyMetadata { | ||
virtual void store_stack() = 0; | ||
virtual void print_stack() = 0; | ||
}; | ||
|
||
}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.