Skip to content

Commit

Permalink
Revert "[ONNX] Refactor to remove inline imports (pytorch#77142)"
Browse files Browse the repository at this point in the history
This reverts commit c08b8f0.

Reverted pytorch#77142 on behalf of https://github.com/malfet
  • Loading branch information
pytorchmergebot committed May 13, 2022
1 parent 6066e59 commit 6b366dd
Show file tree
Hide file tree
Showing 22 changed files with 510 additions and 547 deletions.
7 changes: 3 additions & 4 deletions test/onnx/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
import itertools
import os
import sys
import unittest.mock
import unittest
from typing import Callable, Iterable, Optional, Tuple, Union

import onnx
from test_pytorch_common import TestCase

import torch
from torch.onnx import OperatorExportTypes, symbolic_registry
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _onnx_unsupported
from torch.testing._internal.common_utils import custom_op, skipIfCaffe2

Expand All @@ -30,9 +29,9 @@ def export_to_onnx(
Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
]
] = None,
mocks: Optional[Iterable] = None,
mocks: Optional[Iterable[unittest.mock.patch]] = None,
operator_export_type: OperatorExportTypes = OperatorExportTypes.ONNX,
opset_version: int = GLOBALS.export_onnx_opset_version,
opset_version: int = torch.onnx.symbolic_helper._export_onnx_opset_version,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.
Expand Down
6 changes: 2 additions & 4 deletions test/onnx/test_onnx_opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
import torch.onnx
from torch.nn import Module
from torch.onnx import producer_name, producer_version
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _export_onnx_opset_version


def check_onnx_opset_operator(
model, ops, opset_version=GLOBALS.export_onnx_opset_version
):
def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_version):
# check_onnx_components
assert (
model.producer_name == producer_name
Expand Down
6 changes: 2 additions & 4 deletions test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,16 @@
skipIfUnsupportedMinOpsetVersion,
skipScriptTest,
)

# TODO(justinchuby): Remove reference to other unit tests.
from test_pytorch_onnx_onnxruntime import TestONNXRuntime

import torch
from torch.cuda.amp import autocast
from torch.onnx._globals import GLOBALS


class TestONNXRuntime_cuda(unittest.TestCase):
from torch.onnx.symbolic_helper import _export_onnx_opset_version

opset_version = GLOBALS.export_onnx_opset_version
opset_version = _export_onnx_opset_version
keep_initializers_as_inputs = True
onnx_shape_inference = True

Expand Down
9 changes: 6 additions & 3 deletions test/onnx/test_pytorch_onnx_shape_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion

import torch
from torch.onnx import _constants
from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version
from torch.onnx.symbolic_helper import (
_onnx_main_opset,
_set_onnx_shape_inference,
_set_opset_version,
)


def expect_tensor(scalar_type, shape=None):
Expand All @@ -24,7 +27,7 @@ def verify(actual_type):
class TestONNXShapeInference(unittest.TestCase):
def __init__(self, *args, **kwargs):
unittest.TestCase.__init__(self, *args, **kwargs)
self.opset_version = _constants.onnx_main_opset
self.opset_version = _onnx_main_opset
_set_onnx_shape_inference(True)
_set_opset_version(self.opset_version)

Expand Down
3 changes: 3 additions & 0 deletions test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import io
import unittest

import onnx
import torchvision
Expand Down Expand Up @@ -33,6 +34,8 @@
parse_args,
)

skip = unittest.skip


class _BaseTestCase(TestCase):
def setUp(self):
Expand Down
4 changes: 2 additions & 2 deletions tools/onnx/update_default_opset_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def read_sub_write(path: str, prefix_pat: str) -> None:


read_sub_write(
os.path.join("torch", "onnx", "_constants.py"),
r"(onnx_default_opset = )\d+",
os.path.join("torch", "onnx", "symbolic_helper.py"),
r"(_default_onnx_opset_version = )\d+",
)
read_sub_write(
os.path.join("torch", "onnx", "__init__.py"), r"(opset_version \(int, default )\d+"
Expand Down
9 changes: 4 additions & 5 deletions torch/csrc/jit/passes/onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ void NodeToONNX(
::torch::onnx::OperatorExportTypes operator_export_type,
std::unordered_map<Value*, Value*>& env) {
py::object onnx = py::module::import("torch.onnx");
py::object onnx_globals = py::module::import("torch.onnx._globals");
py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");

// Setup all the lambda helper functions.
Expand Down Expand Up @@ -273,8 +273,8 @@ void NodeToONNX(
}
// For const node, it does not need params_dict info, so set it to {}.
const ParamMap empty_params_dict = {};
auto opset_version = py::cast<int>(
onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
auto opset_version =
py::cast<int>(onnx_symbolic.attr("_export_onnx_opset_version"));
for (const auto i : c10::irange(num_old_outputs)) {
auto old = old_outputs[i];
if (outputs[i]) {
Expand Down Expand Up @@ -435,8 +435,7 @@ void NodeToONNX(
pyobj = func->get();
}

py::object opset_version =
onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
py::object opset_version = onnx_symbolic.attr("_export_onnx_opset_version");
py::object is_registered_op = onnx_registry.attr("is_registered_op")(
"PythonOp", "prim", opset_version);
if (!py::hasattr(pyobj, "symbolic") &&
Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Optional

import torch._C as _C

Expand Down
6 changes: 0 additions & 6 deletions torch/onnx/_constants.py

This file was deleted.

48 changes: 0 additions & 48 deletions torch/onnx/_globals.py

This file was deleted.

Loading

0 comments on commit 6b366dd

Please sign in to comment.