Skip to content

Commit

Permalink
[ONNX] Rename constants for clarity (pytorch#84645)
Browse files Browse the repository at this point in the history
Rename constants to make them more clear. Fix styles to upper case.

Removed `onnx_stable_opsets` because it can be computed from `ONNX_MIN_OPSET` and `ONNX_MAX_OPSET`.

Fixes pytorch#84643

Pull Request resolved: pytorch#84645
Approved by: https://github.com/BowenBao
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 9, 2022
1 parent bc3683d commit 2fa8142
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 29 deletions.
2 changes: 1 addition & 1 deletion test/onnx/onnx_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def set_rng_seed(seed):


class _TestONNXRuntime(common_utils.TestCase):
opset_version = _constants.onnx_default_opset
opset_version = _constants.ONNX_DEFAULT_OPSET
keep_initializers_as_inputs = True # For IR version 3 type export.
is_script = False
check_shape = True
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.onnx_main_opset
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET


def _init_test_generalized_rcnn_transform():
Expand Down
2 changes: 1 addition & 1 deletion test/onnx/test_pytorch_onnx_shape_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def verify(actual_type):

class TestONNXShapeInference(common_utils.TestCase):
def setUp(self):
self.opset_version = _constants.onnx_main_opset
self.opset_version = _constants.ONNX_MAX_OPSET
symbolic_helper._set_onnx_shape_inference(True)
symbolic_helper._set_opset_version(self.opset_version)

Expand Down
2 changes: 1 addition & 1 deletion tools/onnx/update_default_opset_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def main(args: Any) -> None:

read_sub_write(
os.path.join("torch", "onnx", "_constants.py"),
r"(onnx_default_opset = )\d+",
r"(ONNX_DEFAULT_OPSET = )\d+",
new_default,
)
read_sub_write(
Expand Down
10 changes: 6 additions & 4 deletions torch/onnx/_constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Constant values used in ONNX."""

ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
onnx_default_opset = 14
onnx_main_opset = 17
onnx_stable_opsets = tuple(range(7, onnx_main_opset))
onnx_constant_folding_opsets = tuple(range(9, onnx_main_opset + 1))

ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 17
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 14
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9

PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
7 changes: 4 additions & 3 deletions torch/onnx/_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class _InternalGlobals:
"""

def __init__(self):
self._export_onnx_opset_version = _constants.onnx_default_opset
self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
self._in_onnx_export: bool = False
# Whether the user's model is training during export
Expand Down Expand Up @@ -66,8 +66,9 @@ def export_onnx_opset_version(self) -> int:

@export_onnx_opset_version.setter
def export_onnx_opset_version(self, value: int):
supported_versions = [_constants.onnx_main_opset]
supported_versions.extend(_constants.onnx_stable_opsets)
supported_versions = range(
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
)
if value not in supported_versions:
raise ValueError(f"Unsupported ONNX opset version: {value}")
self._export_onnx_opset_version = value
Expand Down
3 changes: 1 addition & 2 deletions torch/onnx/_onnx_supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from torch import _C
from torch.onnx import _constants, symbolic_registry

for v in _constants.onnx_stable_opsets:
for v in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
symbolic_registry.register_version("", v)
symbolic_registry.register_version("", _constants.onnx_main_opset)


class _TorchSchema:
Expand Down
7 changes: 3 additions & 4 deletions torch/onnx/symbolic_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import importlib
import inspect
import itertools
import warnings
from typing import Any, Callable, Dict, Tuple, Union

Expand Down Expand Up @@ -39,8 +38,8 @@


def _import_symbolic_opsets():
for opset_version in itertools.chain(
_constants.onnx_stable_opsets, [_constants.onnx_main_opset]
for opset_version in range(
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
):
module = importlib.import_module(f"torch.onnx.symbolic_opset{opset_version}")
global _symbolic_versions
Expand Down Expand Up @@ -149,7 +148,7 @@ def unregister_op(opname: str, domain: str, version: int):

def get_op_supported_version(opname: str, domain: str, version: int):
iter_version = version
while iter_version <= _constants.onnx_main_opset:
while iter_version <= _constants.ONNX_MAX_OPSET:
ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
if (domain, opname) in ops:
return iter_version
Expand Down
18 changes: 7 additions & 11 deletions torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import copy
import inspect
import io
import itertools
import os
import re
import textwrap
Expand Down Expand Up @@ -1139,7 +1138,8 @@ def _model_to_graph(

if (
do_constant_folding
and GLOBALS.export_onnx_opset_version in _constants.onnx_constant_folding_opsets
and GLOBALS.export_onnx_opset_version
>= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
):
params_dict = _C._jit_pass_onnx_constant_fold(
graph, params_dict, GLOBALS.export_onnx_opset_version
Expand Down Expand Up @@ -1204,7 +1204,7 @@ def export_to_pretty_string(
A UTF-8 str containing a human-readable representation of the ONNX model.
"""
if opset_version is None:
opset_version = _constants.onnx_default_opset
opset_version = _constants.ONNX_DEFAULT_OPSET
if custom_opsets is None:
custom_opsets = {}
symbolic_helper._set_opset_version(opset_version)
Expand Down Expand Up @@ -1265,7 +1265,7 @@ def unconvertible_ops(
of the unconvertible ops.
"""

opset_version = opset_version or _constants.onnx_default_opset
opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET
symbolic_helper._set_opset_version(opset_version)
# operator_export_type is set to ONNX_FALLTHROUGH by default so that if an op is not supported
# in ONNX, fall through will occur and export the operator as is, as a custom ONNX op.
Expand Down Expand Up @@ -1418,7 +1418,7 @@ def _export(
symbolic_helper._set_onnx_shape_inference(onnx_shape_inference)

if opset_version is None:
opset_version = _constants.onnx_default_opset
opset_version = _constants.ONNX_DEFAULT_OPSET

if export_modules_as_functions and opset_version < 15:
raise ValueError(
Expand Down Expand Up @@ -1913,9 +1913,7 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
"""
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)

for version in itertools.chain(
_constants.onnx_stable_opsets, [_constants.onnx_main_opset]
):
for version in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
if version >= opset_version:
symbolic_registry.register_op(op_name, symbolic_fn, ns, version)

Expand All @@ -1933,9 +1931,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
"""
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)

for version in itertools.chain(
_constants.onnx_stable_opsets, [_constants.onnx_main_opset]
):
for version in range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1):
if version >= opset_version:
symbolic_registry.unregister_op(op_name, ns, version)

Expand Down
2 changes: 1 addition & 1 deletion torch/onnx/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _onnx_graph_from_model(
output_names = export_options.output_names

if opset_version is None:
opset_version = _constants.onnx_default_opset
opset_version = _constants.ONNX_DEFAULT_OPSET

utils._setup_trace_module_map(model, export_modules_as_functions)

Expand Down

0 comments on commit 2fa8142

Please sign in to comment.