Skip to content

Commit

Permalink
[dynamo] Support dynamic slicing (pytorch#91341)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#91341
Approved by: https://github.com/voznesenskym
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Jan 10, 2023
1 parent 3139e68 commit 0c3ed2e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
68 changes: 68 additions & 0 deletions test/dynamo/test_export.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import operator
from typing import Dict, List
from unittest.mock import patch

Expand Down Expand Up @@ -1494,6 +1495,73 @@ def f(x: torch.Tensor) -> torch.Tensor:

self.assertTrue(has_sym_size)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
def test_dynamic_slicing(self):
def f(x):
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

gm_aten_mode, _ = torch._dynamo.export(
f, torch.randn(4, 5), aten_graph=True, tracing_mode="symbolic"
)

inp = torch.randn(6, 7)
self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)

count = 0
# aten graph should flatten getitem calls to actual
# slice kernel call.
for node in gm_aten_mode.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.slice.Tensor
):
count += 1

self.assertEqual(count, 2)

gm_torch_mode, _ = torch._dynamo.export(f, torch.randn(4, 5), aten_graph=False)

# In torch mode, the graph should contain 3 getitem methods
# one for x.shape[0]-2 and one for x.shape[1]-1 and one for slice
# this is because Tensor class has its' own getitem method
# which gets translated to aten.Slice later.
count = 0
for node in gm_torch_mode.graph.nodes:
if node.op == "call_function" and node.target == operator.getitem:
count += 1

self.assertEqual(count, 3)
self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
def test_dynamic_slicing_invalid(self):
def g(x, y):
return x[y : x.shape[0]]

with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"Dynamic slicing on data-dependent value is not supported",
):
torch._dynamo.export(
g,
torch.randn(4, 5),
torch.tensor(2),
aten_graph=True,
tracing_mode="symbolic",
)

@patch.object(torch._dynamo.config, "dynamic_shapes", True)
def test_dynamic_slicing_simple(self):
def f(x):
return x[slice(None, None, None)]

gm, _ = torch._dynamo.export(
f, torch.randn(4, 5), aten_graph=True, tracing_mode="symbolic"
)

inp = torch.randn(6, 7)
self.assertEqual(gm(inp), f(inp))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
18 changes: 4 additions & 14 deletions torch/_dynamo/variables/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,6 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker":

class SliceVariable(BaseListVariable):
def __init__(self, items, **kwargs):
from .tensor import DynamicShapeVariable

if any([isinstance(x, DynamicShapeVariable) for x in items]):
unimplemented("Dynamic slicing not supported")

items_to_map = items
start, stop, step = [variables.ConstantVariable(None)] * 3

Expand All @@ -451,15 +446,10 @@ def __init__(self, items, **kwargs):
else:
raise AssertionError()

# Avoids a .item() call in the tensor slice that would attempt to get a
# value out fake tensors, and which would determine the output shape of
# the slice. It is a workaround until
# https://github.com/pytorch/pytorch/pull/83567 is landed and there is
# more complete support for breaking on data dependent operators.
if not config.capture_scalar_outputs:
for limit in (start, stop, step):
if isinstance(limit, (variables.TensorVariable, DynamicShapeVariable)):
unimplemented("Dynamic slicing not supported")
if isinstance(start, variables.TensorVariable) or isinstance(
stop, variables.TensorVariable
):
unimplemented("Dynamic slicing on data-dependent value is not supported")

super().__init__([start, stop, step], **kwargs)

Expand Down

0 comments on commit 0c3ed2e

Please sign in to comment.