Skip to content

Commit

Permalink
[ONNX] Fix node attributes when namespace is aten (pytorch#84211)
Browse files Browse the repository at this point in the history
When `g.at` is used, the previous clean up in pytorch#83136 mistakenly removed the behavior that sets `aten=True` in `_add_attribute`. This PR brings the behavior back.
Pull Request resolved: pytorch#84211
Approved by: https://github.com/thiagocrepaldi, https://github.com/BowenBao
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 9, 2022
1 parent 2fa8142 commit dbdc1cd
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions torch/onnx/_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,18 @@ def _aten_op(g: _C.Graph, operator: str, *args, overload_name: str = "", **kwarg


@_beartype.beartype
def _block_op(b: _C.Block, opname: str, *args: _C.Value, **kwargs):
def _block_op(block: _C.Block, opname: str, *args: _C.Value, **kwargs):
if "::" in opname:
aten = False
ns_opname = opname
namespace, op = opname.split("::")
else:
aten = kwargs.pop("aten", False)
ns = "aten" if aten else "onnx"
ns_opname = ns + "::" + opname
n = b.addNode(ns_opname, args)
namespace = "onnx"
op = opname

n = block.addNode(f"{namespace}::{op}", args)
aten = namespace == "aten"
skip_attrs = {"inplace", "aten"}
for k, v in sorted(kwargs.items()):
if k == "inplace":
if k in skip_attrs:
continue
_add_attribute(n, k, v, aten=aten)
outputs = tuple(n.outputs())
Expand All @@ -135,10 +136,11 @@ def _new_node(
Returns:
The new node.
"""
aten = kwargs.pop("aten", False)
aten = namespace == "aten"
node = g.create(f"{namespace}::{op}", args, outputs)
skip_attrs = {"inplace", "aten"}
for k, v in sorted(kwargs.items()):
if k == "inplace":
if k in skip_attrs:
continue
_add_attribute(node, k, v, aten=aten)
return node
Expand Down Expand Up @@ -175,7 +177,7 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
if m is None:
raise ValueError(
f"Invalid attribute specifier '{key}' names "
" must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
"must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
)
name, kind = m.group(1), m.group(2)
if _is_onnx_list(value):
Expand Down

0 comments on commit dbdc1cd

Please sign in to comment.