Skip to content

Commit

Permalink
Support parsing Ellipsis in JIT frontend (pytorch#53576)
Browse files Browse the repository at this point in the history
Summary:
De-sugars `Ellipsis` into dots (`...`)

Fixes pytorch#53517

Pull Request resolved: pytorch#53576

Reviewed By: pbelevich

Differential Revision: D26904361

Pulled By: gmagogsfm

fbshipit-source-id: 5b23e049a075a9a99e37dcb47a9410b6f82a6fb7
  • Loading branch information
gmagogsfm authored and facebook-github-bot committed Mar 9, 2021
1 parent c2ccb35 commit bf88a4d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10756,6 +10756,36 @@ def forward(self, x):

self.checkModule(C(), (torch.tensor(1),))

def test_ellipsis_const_mid(self):
def ellipsize(x):
# type: (Tensor) -> List[int]
return x[2, Ellipsis, 0:4, 4:8].size() # noqa T484

dummy = torch.zeros(8, 8, 8, 8, 8)
self.checkScript(ellipsize, (dummy,), optimize=True)

def test_ellipsis_const_mid_select(self):
def ellipsize(x):
# type: (Tensor) -> List[int]
return x[2, Ellipsis, 4, 4, 4:8, 2].size() # noqa T484

dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
self.checkScript(ellipsize, (dummy,), optimize=True)

def test_ellipsis_const_start(self):
def ellipsize(x):
# type: (Tensor) -> List[int]
return x[Ellipsis, 0:4, 4:8].size() # noqa T484
dummy = torch.zeros(8, 8, 8, 8, 8)
self.checkScript(ellipsize, (dummy,), optimize=True)

def test_ellipsis_const_end(self):
def ellipsize(x):
# type: (Tensor) -> List[int]
return x[0:4, 2, Ellipsis].size() # noqa T484
dummy = torch.zeros(8, 8, 8, 8, 8)
self.checkScript(ellipsize, (dummy,), optimize=True)

def test_ellipsis_mid(self):
def ellipsize(x):
# type: (Tensor) -> List[int]
Expand Down
4 changes: 4 additions & 0 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,8 @@ def build_Name(ctx, expr):
return FalseLiteral(r)
elif expr.id == "None":
return NoneLiteral(r)
elif expr.id == "Ellipsis":
return Dots(r)
return Var(Ident(r, expr.id))

@staticmethod
Expand All @@ -628,6 +630,8 @@ def build_NameConstant(ctx, expr):
return FalseLiteral(r)
elif expr.value is None:
return NoneLiteral(r)
elif expr.value == Ellipsis:
return Dots(r)
else:
raise ValueError("Name constant value unsupported: " + str(expr.value))

Expand Down

0 comments on commit bf88a4d

Please sign in to comment.