Skip to content

Commit

Permalink
Don't DCE PythonOp
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#14773

Reviewed By: eellison

Differential Revision: D13327673

Pulled By: suo

fbshipit-source-id: 236db3407c7eacac470530836e3d4d0dc323110c
  • Loading branch information
apaszke authored and facebook-github-bot committed Dec 5, 2018
1 parent 8dfebc1 commit c79e305
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 12 deletions.
9 changes: 0 additions & 9 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8512,15 +8512,6 @@ def foo(cond):
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
foo(torch.tensor(1))

@torch.jit.script
def foo():
a = Exception()
raise a

# a gets DCEd because the expression following raise is ignored
with self.assertRaisesRegex(torch.jit.Error, "failed in interpreter"):
foo()

@torch.jit.script
def foo_except_used():
a = Exception()
Expand Down
4 changes: 1 addition & 3 deletions torch/csrc/jit/passes/dead_code_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,12 @@ class DeadCodeEliminator {
}

bool hasSideEffects(Node* node) {
// FIXME: PythonOp should be treated as having side effects as well!
// Unfortunately ONNX depends on it getting removed in this pass, so
// it's not a simple change.
auto it = memo_.find(node);
if (it != memo_.end())
return it->second;
bool has_side_effects = node->kind() == prim::Print ||
node->kind() == prim::RaiseException ||
node->kind() == prim::PythonOp ||
std::any_of(node->blocks().begin(),
node->blocks().end(),
[&](Block* b) {
Expand Down
3 changes: 3 additions & 0 deletions torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class BatchNorm1d(_BatchNorm):
https://arxiv.org/abs/1502.03167
"""

@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
Expand Down Expand Up @@ -235,6 +236,7 @@ class BatchNorm2d(_BatchNorm):
https://arxiv.org/abs/1502.03167
"""

@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
Expand Down Expand Up @@ -309,6 +311,7 @@ class BatchNorm3d(_BatchNorm):
https://arxiv.org/abs/1502.03167
"""

@weak_script_method
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
Expand Down

0 comments on commit c79e305

Please sign in to comment.