From 3cc46002a3b8490a89a0cf46b9150b676ba9a962 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 4 Feb 2021 12:35:27 -0800 Subject: [PATCH] [ONNX] Fix graph position to insert clone node for inplace op removal (#50123) (#51520) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51520 Previous insertBefore approach might end-up inserting clone node in inner sub-blocks, while then the node being used later at other outside call sites. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26203124 Pulled By: SplitInfinity fbshipit-source-id: 999511e901ad1087f360bb689fcdfc3743c78aa4 --- test/onnx/test_pytorch_onnx_onnxruntime.py | 14 ++++++++++++++ .../passes/onnx/remove_inplace_ops_for_onnx.cpp | 1 + 2 files changed, 15 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 8c3bf0cbd204a..8a11df49de2ed 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -3513,6 +3513,20 @@ def forward(self, input, other): y = torch.randn(6, 4) self.run_test(ViewModel(), (x, y)) + def test_linear(self): + class LinearModel(torch.nn.Module): + def __init__(self): + super(LinearModel, self).__init__() + self.fc = torch.nn.Linear(16, 16) + + def forward(self, x): + out = self.fc(x) + out = self.fc(out) + return out + + x = torch.randn(3, 16) + self.run_test(LinearModel(), (x,)) + @disableScriptTest() def test_weight_norm(self): # addmm for 3-d inputs converts to onnx::MatMul diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 6ec6dd3406623..ebcfa3188ab11 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -771,6 +771,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { b->prependNode(newNode); noneNode->insertBefore(newNode); } + TORCH_INTERNAL_ASSERT(nullptr != newNode); node->replaceInput(index, newNode->output()); input->replaceAllUsesAfterNodeWith(node, newNode->output()); }