use flake8-mypy (pytorch#17721)
Summary:
Use flake8 with the flake8-mypy plugin installed so that our linter matches fbcode. Mypy type errors also provide valuable signal.
Pull Request resolved: pytorch#17721

Differential Revision: D14357778

Pulled By: eellison

fbshipit-source-id: d8c9ea3fe3b5f550c3b70fe259e0eabf95e4c92d
Elias Ellison authored and facebook-github-bot committed Mar 7, 2019
1 parent 1d52259 commit 561037a
Showing 6 changed files with 60 additions and 55 deletions.
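Context for the changes below: flake8-mypy runs mypy inside flake8 and reports mypy type errors under the T484 code, which is why the diff both swaps `pip install flake8` for `pip install flake8-mypy` and sprinkles `# noqa: T484` on lines where a test constructs a deliberate type error. A minimal sketch of that suppression pattern (invented function, not code from this diff):

```python
from typing import Optional


def add_one(x):
    # type: (Optional[int]) -> int
    # mypy rejects `x + 1` because `x` may be None; flake8-mypy surfaces
    # that as T484. The noqa below silences the linter on this line only,
    # while the type error itself is deliberately kept.
    return x + 1  # noqa: T484


print(add_one(3))  # prints 4; the mismatch only exists at lint time
```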
2 changes: 1 addition & 1 deletion .flake8
@@ -1,4 +1,4 @@
 [flake8]
 max-line-length = 120
 ignore = E203,E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
-exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install
+exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include
2 changes: 1 addition & 1 deletion .travis.aten.yml
@@ -27,5 +27,5 @@ matrix:
   include:
     env: LINT_CHECK
     python: "2.7"
-    install: pip install flake8
+    install: pip install flake8-mypy
     script: flake8
2 changes: 1 addition & 1 deletion .travis.yml
@@ -29,7 +29,7 @@ matrix:
     python: "3.7"
     dist: xenial # required for Python 3.7 (travis-ci/travis-ci#9069)
     sudo: required # required for Python 3.7 (travis-ci/travis-ci#9069)
-    install: pip install flake8
+    install: pip install flake8-mypy
     script: flake8
   - name: "MyPy typecheck"
     python: "3.6"
87 changes: 45 additions & 42 deletions test/test_jit.py
@@ -5,11 +5,13 @@
 import torch.nn.functional as F
 import torch.nn.parallel as dp
 import torch.optim as optim
+import torch.cuda
 import torch.jit.quantized
 from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
+from torch.nn import Module
 from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
@@ -44,9 +46,11 @@
     ListType, StringType, DictType
 from copy import deepcopy
 import random
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
 from torch.jit.frontend import NotSupportedError
 from torch.jit import BatchTensor
+from torch import Tensor
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3

 # For testing truediv in python 2
 from test_module.future_div import div_int_future, div_float_future
@@ -96,7 +100,7 @@ def TemporaryFileName():
         finally:
             os.unlink(f.name)
 else:
-    @contextmanager
+    @contextmanager # noqa: T484
     def TemporaryFileName():
         with tempfile.NamedTemporaryFile() as f:
             yield f.name
@@ -2262,7 +2266,7 @@ def hints(x, a=0.5, b=10):
         with self.assertRaisesRegex(RuntimeError, "Expected a default value"):

             @torch.jit.script
-            def hints_bad_types(x, a=10, b=0.5):
+            def hints_bad_types(x, a=10, b=0.5): # noqa: T484
                 # type: (Tensor, float, int) -> Tensor
                 return x + a + b

@@ -3113,7 +3117,7 @@ def test_sum_list_wrong_type(self):
             def sum_list(a):
                 # type: (int) -> int
                 sum = 0
-                for i in a:
+                for i in a: # noqa: T484
                     sum += i

                 return sum
@@ -4727,39 +4731,39 @@ def test_manual_unwrap_opt(x):
                 x = 1
             else:
                 x = torch.jit._unwrap_optional(x)
-            return x
+            return x # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def or_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None or y is None:
-                    print(x + y)
+                    print(x + y) # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def and_error(x, y):
-                # type: (Optional[int], Optional[int]) -> int
+                # type: (Optional[int], Optional[int]) -> None
                 if x is None and y is None:
                     pass
                 else:
-                    print(x + y)
+                    print(x + y) # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def named_var(x):
                 # type: (Optional[int]) -> None
                 x_none = x is not None
                 if x_none:
-                    print(x + 1)
+                    print(x + 1) # noqa: T484

         with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
             @torch.jit.script
             def named_var_and(x, y):
                 # type: (Optional[int], Optional[int]) -> None
                 x_none = x is not None
                 if y is not None and x_none:
-                    print(x + y)
+                    print(x + y) # noqa: T484

     def test_while_write_outer_then_read(self):
         def func(a, b):
@@ -5057,10 +5061,11 @@ def multiple_returns(a):
         self.checkScript(multiple_returns, [a], optimize=True)

         with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
-            @torch.jit.script
+            torch.jit.CompilationUnit('''
             def no_return_bad_annotation(a):
                 # type: (Tensor) -> Tensor
                 a + 1
+            ''')

     def test_error(self):
         @torch.jit.script
@@ -5654,8 +5659,6 @@ def test_rnn_cell_quantized(self):
         hiddens = hx

         if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
-            from typing import Tuple
-
             class ScriptWrapper(torch.jit.ScriptModule):
                 def __init__(self, cell):
                     super(ScriptWrapper, self).__init__()
@@ -6650,7 +6653,7 @@ def test_python_call_non_tensor_wrong(self):
         with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
             def foo():
                 # type: () -> Tensor
-                return ((3, 4),)
+                return ((3, 4),) # noqa: T484

             @torch.jit.script
             def bar():
@@ -6769,7 +6772,7 @@ def list_optional_fails(x):
             if x:
                 y = [1]
             else:
-                y = [None]
+                y = [None] # noqa: T484
             return y[0]

         @torch.jit.script
@@ -6815,18 +6818,18 @@ def int_fn_call():
             print(int_fn((1, 1, 1)))

         with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
-            @torch.jit.script
+            @torch.jit.script # noqa: T484
             def fn(x):
-                # type: (BroadcastingListx[int]) -> List[int]
+                # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
                 return x

-        # TODO: the type comment in this seems to trip up flake8 for some reason
-        # even though we have a noqa comment. Figure out why
+        # using CU so that flake8 error on int[2] is not raised (noqa not working)
         with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
-            @torch.jit.script
-            def nested(x, y):
-                # type: (int, Tuple[int, int[2]]) -> List[int] # noqa: T484
-                return x
+            cu = torch.jit.CompilationUnit('''
+                def nested(x, y):
+                    # type: (int, Tuple[int, int[2]]) -> List[int]
+                    return x # noqa: T484
+            ''')

     def test_ntuple_builtins(self):
         from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
@@ -8349,7 +8352,7 @@ def test_wrong_return_type(self):
         with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
             def somefunc():
                 # type: () -> Tuple[Tuple[Tensor, Tensor]]
-                return torch.zeros(3, 4), torch.zeros(4, 5)
+                return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484

             @torch.jit.script
             def wrong_return_type():
@@ -9029,7 +9032,7 @@ def test_unwrap_optional_builtin(self):
         def test(x):
             # type: (Optional[int]) -> int
             x = torch.jit._unwrap_optional(x)
-            x = x + x
+            x = x + x # noqa: T484
             return x

         self.checkScript(test, (3,))
@@ -9082,14 +9085,14 @@ def test_annotated_script_fn_return_mismatch(self):
             @torch.jit.script
             def return_tup(x):
                 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
-                return x, x
+                return x, x # noqa: T484

     def test_annotated_script_fn_arg_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
             @torch.jit.script
             def tuple_arg(x):
                 # type: (Tuple[Tensor, Tensor]) -> Tensor
-                return x + 1
+                return x + 1 # noqa: T484

     def test_script_non_tensor_args_outputs(self):
         @torch.jit.script
@@ -13122,11 +13125,11 @@ def wait_script(x):
         self.assertEqual(y, y_hat)

     def test_async_script_capture(self):
-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             __constants__ = ['const']

             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 self.const = 42
                 self.param = nn.Parameter(torch.randn(2, 2))

@@ -13144,7 +13147,7 @@ def wait_script(self, x1, x2):
         x1 = torch.rand(3, 4)
         x2 = torch.rand(5, 6)

-        m = Module()
+        m = Mod()
         y, y_hat = m.wait_script(x1, x2)

         self.assertEqual(y, y_hat)
@@ -13244,9 +13247,9 @@ def __init__(self):
             def forward(self, x):
                 return (torch.neg(x), x)

-        class Module(torch.jit.ScriptModule):
+        class Mod(torch.jit.ScriptModule):
             def __init__(self):
-                super(Module, self).__init__(False)
+                super(Mod, self).__init__(False)
                 x = torch.rand(3, 3)
                 self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)

@@ -13266,10 +13269,10 @@ def forward(self, x):
             # return a nested structure of tensors
             return (tensor_list, tensor_tuple, tensor_tuple[1])

-        class Tuple(nn.Module):
+        class TupleCl(nn.Module):
             def __init__(self):
-                super(Tuple, self).__init__()
-                self.module = Module()
+                super(TupleCl, self).__init__()
+                self.module = Mod()

             def forward(self, x):
                 z = torch.neg(x)
@@ -13278,7 +13281,7 @@ def forward(self, x):
             return tuple(list)

         x = torch.rand(3, 3)
-        module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)

         # Make sure we have forks
         self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
@@ -13632,16 +13635,16 @@ def test_set_attr_in_method(self):
         @torch.jit.script
         class FooTest:
             def __init__(self, x):
-                # type: (int)
+                # type: (int) -> None
                 self.foo = x

             def incFooTest(self, y):
-                # type: (int)
+                # type: (int) -> None
                 self.foo = self.foo + y

         @torch.jit.script
         def fn(x):
-            # type: (int)
+            # type: (int) -> int
             foo = FooTest(x)
             foo.incFooTest(2)
             return foo.foo
@@ -13689,7 +13692,7 @@ def test_type_annotations(self):
         @torch.jit.script
         class FooTest:
             def __init__(self, x):
-                # type: (bool)
+                # type: (bool) -> None
                 self.foo = x

         @torch.jit.script
@@ -13718,7 +13721,7 @@ def __init__(self, x):

         @torch.jit.script
         def fn(foo):
-            # type: (FooTest)
+            # type: (FooTest) -> Tensor
             return foo.attr

         @torch.jit.script
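A side note on the `int[2]` test above: flake8 only lints real Python source, so moving the deliberately bogus type comment into a `torch.jit.CompilationUnit` string hides it from the linter entirely (per the diff's own comment, `# noqa` alone was not honored there). A sketch of why that works, reusing the `nested` example and the error the test asserts on:

```python
import torch

# The linter never parses string literals, so the invalid `int[2]` type
# constructor below is invisible to flake8/flake8-mypy; only the
# TorchScript compiler sees it and raises, which is what the test checks.
try:
    torch.jit.CompilationUnit('''
    def nested(x, y):
        # type: (int, Tuple[int, int[2]]) -> List[int]
        return x
    ''')
except RuntimeError as e:
    assert "Unknown type constructor" in str(e)
```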
18 changes: 9 additions & 9 deletions torch/_jit_internal.py
@@ -9,24 +9,24 @@
 from torch._six import builtins

 # Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary() # noqa: T484

 # Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary() # noqa: T484

 # Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary() # noqa: T484

 # Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary() # noqa: T484

 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
-boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary() # noqa: T484

 # Python Op functions that should be ignored by the compiler. These will be replaced
 # with an operator that always throws an error
-ignored_fns = weakref.WeakSet()
+ignored_fns = weakref.WeakSet() # noqa: T484

 COMPILATION_PENDING = object()
 COMPILED = object()
@@ -223,9 +223,9 @@ class DictCls(object):
     def __getitem__(self, types):
         return DictInstance(types)

-Tuple = TupleCls()
-List = ListCls()
-Dict = DictCls()
+Tuple = TupleCls() # noqa: T484
+List = ListCls() # noqa: T484
+Dict = DictCls() # noqa: T484

 def is_tuple(ann):
     return isinstance(ann, TupleInstance)
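The module-level suppressions in torch/_jit_internal.py are the pragmatic route; our reading (an assumption, since the commit only adds `noqa`) is that mypy wants an annotation for a bare `WeakKeyDictionary()` and objects to the `Tuple`/`List`/`Dict` instances shadowing the `typing` generics of the same names. A hypothetical annotated alternative, shown only for contrast with what the commit actually does:

```python
# Hypothetical sketch: annotate instead of suppress. Python-2-compatible
# type comments hand mypy the parameter types it cannot infer from a
# bare constructor call. Not what this commit does.
import weakref
from typing import Any

compiled_weak_fns = weakref.WeakKeyDictionary()  # type: weakref.WeakKeyDictionary[Any, Any]
ignored_fns = weakref.WeakSet()  # type: weakref.WeakSet[Any]
```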
4 changes: 3 additions & 1 deletion torch/jit/quantized.py
@@ -1,7 +1,9 @@
 import torch
 import copy
 import numbers
-from typing import Tuple
+from typing import Tuple, Optional
+from torch import Tensor
 from torch.jit import ScriptModule

 from torch.nn.utils.rnn import PackedSequence
 from torch.nn import _VF
