Skip to content

Commit

Permalink
Throw a proper error when parsing local variable annotations without …
Browse files Browse the repository at this point in the history
…assignments (pytorch#34133)

Summary:
Currently, putting `outputs: List[Tensor]` instead of `outputs: List[Tensor] = []` in your JITed code results in:
```
Traceback (most recent call last):
  File "custom_lstms.py", line 453, in <module>
    test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
  File "custom_lstms.py", line 404, in test_script_stacked_bidir_rnn
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
  File "custom_lstms.py", line 62, in script_lstm
    other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size]))
  File "/home/apaszke/pytorch/torch/jit/__init__.py", line 1267, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 305, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 348, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/apaszke/pytorch/torch/jit/__init__.py", line 1612, in _construct
    init_fn(script_module)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 340, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 348, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/apaszke/pytorch/torch/jit/__init__.py", line 1612, in _construct
    init_fn(script_module)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 340, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 348, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/apaszke/pytorch/torch/jit/__init__.py", line 1612, in _construct
    init_fn(script_module)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 340, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 348, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/apaszke/pytorch/torch/jit/__init__.py", line 1612, in _construct
    init_fn(script_module)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 340, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 317, in create_script_module_impl
    stubs = stubs_fn(nn_module)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 511, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 41, in make_stub_from_method
    return make_stub(func)
  File "/home/apaszke/pytorch/torch/jit/_recursive.py", line 34, in make_stub
    ast = torch.jit.get_jit_def(func, self_name="RecursiveScriptModule")
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 173, in get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 206, in build_def
    build_stmts(ctx, body))
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 129, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 129, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 181, in __call__
    return method(ctx, node)
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 294, in build_AnnAssign
    rhs = build_expr(ctx, stmt.value)
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 180, in __call__
    raise UnsupportedNodeError(ctx, node)
  File "/home/apaszke/pytorch/torch/jit/frontend.py", line 116, in __init__
    source_range = ctx.make_range(offending_node.lineno,
AttributeError: 'NoneType' object has no attribute 'lineno'
```

This patch makes the error message more reasonable:
```
torch.jit.frontend.UnsupportedNodeError: annotated assignments without assigned value aren't supported:
  File "custom_lstms.py", line 221
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        inputs = reverse(input.unbind(0))
        outputs: List[Tensor]
        ~ <--- HERE
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
```
Pull Request resolved: pytorch#34133

Differential Revision: D20249076

Pulled By: ezyang

fbshipit-source-id: 40ec34ad38859f9fe56f379d3f8d08644b00fab9
  • Loading branch information
apaszke authored and facebook-github-bot committed Mar 5, 2020
1 parent ed11e25 commit 3a4bac5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ class NotSupportedError(FrontendError):


class UnsupportedNodeError(NotSupportedError):
def __init__(self, ctx, offending_node):
def __init__(self, ctx, offending_node, reason=''):
# If we don't have a specific token, we default to length of 1
node_type = type(offending_node)
range_len = len(node_start_tokens.get(node_type, ' '))
source_range = ctx.make_range(offending_node.lineno,
offending_node.col_offset,
offending_node.col_offset + range_len)
feature_name = pretty_node_names.get(node_type, node_type.__name__)
msg = "{} aren't supported".format(feature_name)
msg = "{} {}aren't supported".format(feature_name, reason + ' ' if reason else '')
super(UnsupportedNodeError, self).__init__(source_range, msg)


Expand Down Expand Up @@ -291,6 +291,8 @@ def build_Assign(ctx, stmt):

@staticmethod
def build_AnnAssign(ctx, stmt):
if stmt.value is None:
raise UnsupportedNodeError(ctx, stmt, reason='without assigned value')
rhs = build_expr(ctx, stmt.value)
lhs = build_expr(ctx, stmt.target)
the_type = build_expr(ctx, stmt.annotation)
Expand Down

0 comments on commit 3a4bac5

Please sign in to comment.