Skip to content

Commit

Permalink
[ONNX] Fix graph sequence output from loop node (pytorch#51305) (pyto…
Browse files Browse the repository at this point in the history
…rch#51521)

Summary:
Pull Request resolved: pytorch#51521

* Add loop & if node to the list of nodes that could produce sequence type output.
* Switch from `[]` to `at()` to avoid segfault of out of range access.

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26203112

Pulled By: SplitInfinity

fbshipit-source-id: e990eeed933124b195be0be159271e33fb485063
  • Loading branch information
BowenBao authored and facebook-github-bot committed Feb 4, 2021
1 parent 3cc4600 commit 586c2e8
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 27 deletions.
1 change: 0 additions & 1 deletion scripts/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
done
pytest "${args[@]}" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_IRv4_old_jit_API"
fi

Expand Down
13 changes: 3 additions & 10 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
from typing import List, Tuple, Optional
import model_defs.word_language_model as word_language_model

import onnx

import torchvision
from torchvision import ops
from torchvision.models.detection.image_list import ImageList
Expand All @@ -26,7 +29,6 @@
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from collections import OrderedDict
import onnx

def to_numpy(tensor):
if tensor.requires_grad:
Expand Down Expand Up @@ -3876,7 +3878,6 @@ def forward(self, x):
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
self.run_test(model, inputs)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_loop_with_list(self):
class ListLoopModel(torch.jit.ScriptModule):
Expand Down Expand Up @@ -6063,7 +6064,6 @@ def forward(self, box_regression: torch.Tensor, proposals: List[torch.Tensor]):
convert_to_onnx(model, input=(box_regression, proposal),
example_outputs=outputs, use_new_jit_passes=True)

@skipIfUnsupportedOpsetVersion([13])
def test_initializer_sequence(self):
class MyModule(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
Expand Down Expand Up @@ -6681,12 +6681,5 @@ def setup_rnn_tests():
keep_initializers_as_inputs=False,
use_new_jit_passes=False))


# opset 12 tests, with _onnx_shape_inference=True.
TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
(unittest.TestCase,),
dict(TestONNXRuntime.__dict__, opset_version=12,
onnx_shape_inference=True))

if __name__ == '__main__':
unittest.main()
42 changes: 26 additions & 16 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ bool HasSequenceTypeOutput(Node* node) {
node->kind() == ::c10::onnx::SequenceInsert ||
node->kind() == ::c10::onnx::SequenceEmpty ||
node->kind() == ::c10::onnx::SequenceErase ||
node->kind() == ::c10::onnx::SequenceConstruct)
node->kind() == ::c10::onnx::SequenceConstruct ||
node->kind() == ::c10::onnx::Loop || node->kind() == ::c10::onnx::If)
return true;
return false;
}
Expand Down Expand Up @@ -618,7 +619,7 @@ void ONNXAssignOutputShape(

if (PyList_Check(elem)) {
size_t list_len = PyList_GET_SIZE(elem);
if (HasSequenceTypeOutput(graph->outputs()[outputs_index]->node())) {
if (HasSequenceTypeOutput(graph->outputs().at(outputs_index)->node())) {
if (list_len > 0) {
auto& var =
reinterpret_cast<THPVariable*>(PyList_GET_ITEM(elem, 0))->cdata;
Expand All @@ -630,15 +631,18 @@ void ONNXAssignOutputShape(
var.scalar_type() == new_var.scalar_type(),
"Unsupported sequence type in model outputs. ONNX supports sequences of elements of the same data type.");
}
auto elem_type = graph->outputs()[outputs_index]
auto elem_type = graph->outputs()
.at(outputs_index)
->type()
->castRaw<ListType>()
->getElementType()
->cast<TensorType>();
elem_type = elem_type->withScalarType(var.scalar_type());
graph->outputs()[outputs_index]->setType(MergeInferredType(
graph->outputs()[outputs_index]->type(),
ListType::create(elem_type)));
graph->outputs()
.at(outputs_index)
->setType(MergeInferredType(
graph->outputs().at(outputs_index)->type(),
ListType::create(elem_type)));
outputs_index++;
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
Expand All @@ -652,9 +656,11 @@ void ONNXAssignOutputShape(
PyObject* list_elem = PyList_GET_ITEM(elem, j);
TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
auto& var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
graph->outputs()[outputs_index + j]->type(),
TensorType::create(var)));
graph->outputs()
.at(outputs_index + j)
->setType(MergeInferredType(
graph->outputs().at(outputs_index + j)->type(),
TensorType::create(var)));
}
outputs_index += list_len;
TORCH_INTERNAL_ASSERT(
Expand All @@ -669,9 +675,11 @@ void ONNXAssignOutputShape(
PyObject* tuple_elem = PyTuple_GET_ITEM(elem, j);
TORCH_INTERNAL_ASSERT(THPVariable_Check(tuple_elem));
auto& var = reinterpret_cast<THPVariable*>(tuple_elem)->cdata;
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
graph->outputs()[outputs_index + j]->type(),
TensorType::create(var)));
graph->outputs()
.at(outputs_index + j)
->setType(MergeInferredType(
graph->outputs().at(outputs_index + j)->type(),
TensorType::create(var)));
}
outputs_index += tuple_len;
TORCH_INTERNAL_ASSERT(
Expand All @@ -681,7 +689,7 @@ void ONNXAssignOutputShape(
} else if (THPVariable_Check(elem)) {
at::Tensor var = reinterpret_cast<THPVariable*>(elem)->cdata;
ONNXUpdateTypeFromTensor(
graph->outputs()[outputs_index], var, onnx_shape_inference);
graph->outputs().at(outputs_index), var, onnx_shape_inference);
outputs_index++;
TORCH_INTERNAL_ASSERT(
outputs_index <= graph->outputs().size(),
Expand All @@ -700,9 +708,11 @@ void ONNXAssignOutputShape(
auto& var =
reinterpret_cast<THPVariable*>(PyTuple_GET_ITEM(tuple_elem, 1))
->cdata;
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
graph->outputs()[outputs_index + j]->type(),
TensorType::create(var)));
graph->outputs()
.at(outputs_index + j)
->setType(MergeInferredType(
graph->outputs().at(outputs_index + j)->type(),
TensorType::create(var)));
}
outputs_index += unrolled_dict.size();
TORCH_INTERNAL_ASSERT(
Expand Down

0 comments on commit 586c2e8

Please sign in to comment.