From bf88a4dad5b3c06f43a9ce07a51ec0ab0dbfe2d3 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 9 Mar 2021 00:00:14 -0800 Subject: [PATCH] Support parsing Ellipsis in JIT frontend (#53576) Summary: De-sugars `Ellipsis` into dots (`...`) Fixes https://github.com/pytorch/pytorch/issues/53517 Pull Request resolved: https://github.com/pytorch/pytorch/pull/53576 Reviewed By: pbelevich Differential Revision: D26904361 Pulled By: gmagogsfm fbshipit-source-id: 5b23e049a075a9a99e37dcb47a9410b6f82a6fb7 --- test/test_jit.py | 30 ++++++++++++++++++++++++++++++ torch/jit/frontend.py | 4 ++++ 2 files changed, 34 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index ff8b35b1e734a..6e9f4e84d7e10 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10756,6 +10756,36 @@ def forward(self, x): self.checkModule(C(), (torch.tensor(1),)) + def test_ellipsis_const_mid(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[2, Ellipsis, 0:4, 4:8].size() # noqa T484 + + dummy = torch.zeros(8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + + def test_ellipsis_const_mid_select(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[2, Ellipsis, 4, 4, 4:8, 2].size() # noqa T484 + + dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + + def test_ellipsis_const_start(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[Ellipsis, 0:4, 4:8].size() # noqa T484 + dummy = torch.zeros(8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + + def test_ellipsis_const_end(self): + def ellipsize(x): + # type: (Tensor) -> List[int] + return x[0:4, 2, Ellipsis].size() # noqa T484 + dummy = torch.zeros(8, 8, 8, 8, 8) + self.checkScript(ellipsize, (dummy,), optimize=True) + def test_ellipsis_mid(self): def ellipsize(x): # type: (Tensor) -> List[int] diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index f7b63e6ae71a0..cfcc91e57e77f 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -617,6 +617,8 @@ def build_Name(ctx, expr): return FalseLiteral(r) elif expr.id == "None": return NoneLiteral(r) + elif expr.id == "Ellipsis": + return Dots(r) return Var(Ident(r, expr.id)) @staticmethod @@ -628,6 +630,8 @@ def build_NameConstant(ctx, expr): return FalseLiteral(r) elif expr.value is None: return NoneLiteral(r) + elif expr.value == Ellipsis: + return Dots(r) else: raise ValueError("Name constant value unsupported: " + str(expr.value))