Skip to content

Commit

Permalink
Delete torch/__init__.pyi, deferring to direct extension stubs (pytorch#38157)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#38157

This removes the error prone process of assembling `torch/__init__.pyi`
(and frequently forgetting to expose things), since now we can simply
rely on the true source file to get things done.  Most of the old
codegen in gen_pyi.py is now rerouted to various files:

- `torch/_C/__init__.pyi` (the dumping pile of all misc bindings)
- `torch/_C/_nn.pyi` (NN function bindings)
- `torch/_C/_VariableFunctions.pyi` (torch function bindings)

`torch.types` grew a bunch more definitions that previously were
defined in `torch/__init__.pyi`

Some miscellaneous changes

- Fixed a bug where we treated a single TensorList argument as implying
  that varargs are accepted. This is actually only supported on IntList.
  This means we can correctly generate a stub for dequantize.
- Add missing manual stub for nonzero
- Switched torch/onnx/operators.py to directly refer to _C module,
  since apparently mypy doesn't think that methods prefixed with
  underscores get reexported.  This may be a recurring theme; maybe
  we need to find a better way to solve it.

Because I was really lazy, I dumped namedtuple definitions in both
`torch._C` and `torch._C._VariableFunctions`.  This is definitely wrong.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497400

Pulled By: ezyang

fbshipit-source-id: 07b126141c82efaca37be27c07255cb2b9b3f064
  • Loading branch information
ezyang authored and facebook-github-bot committed May 11, 2020
1 parent 6f396e1 commit 6edf340
Show file tree
Hide file tree
Showing 19 changed files with 263 additions and 296 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ third_party/build/
tools/shared/_utils_internal.py
torch.egg-info/
torch/__init__.pyi
torch/_C/__init__.pyi
torch/_C/_nn.pyi
torch/_C/_VariableFunctions.pyi
torch/nn/functional.pyi
torch/nn/modules/*.pyi
torch/csrc/autograd/generated/*
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ ignore_errors = True
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
ignore_errors = True

[mypy-torch.onnx.operators]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset8]
ignore_errors = True

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def print_box(msg):
'py.typed',
'bin/*',
'test/*',
'__init__.pyi',
'_C/*.pyi',
'cuda/*.pyi',
'optim/*.pyi',
'autograd/*.pyi',
Expand Down
36 changes: 22 additions & 14 deletions tools/pyi/gen_pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@
'div_out',
'true_divide', 'true_divide_', 'true_divide_out',
'floor_divide', 'floor_divide_', 'floor_divide_out',
'dequantize',
]


Expand Down Expand Up @@ -320,22 +319,19 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):
numargs = len(decl['arguments'])
vararg_pos = int(is_tensor)
have_vararg_version = (numargs > vararg_pos and
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef', 'TensorList'} and
decl['arguments'][vararg_pos]['dynamic_type'] in {'IntArrayRef'} and
(numargs == vararg_pos + 1 or python_args[vararg_pos + 1] == '*') and
(not is_tensor or decl['arguments'][0]['name'] == 'self'))

type_hints.append(type_hint)

if have_vararg_version:
# Two things come into play here: PyTorch has the "magic" that if the first and only positional argument
# is an IntArrayRef or TensorList, it will be used as a vararg variant.
# is an IntArrayRef, it will be used as a vararg variant.
# The following outputs the vararg variant, the "pass a list variant" is output above.
# The other thing is that in Python, the varargs are annotated with the element type, not the list type.
typelist = decl['arguments'][vararg_pos]['dynamic_type']
if typelist == 'IntArrayRef':
vararg_type = '_int'
else:
vararg_type = 'Tensor'
vararg_type = '_int'
# replace first argument and eliminate '*' if present
python_args = ((['self'] if is_tensor else []) + ['*' + decl['arguments'][vararg_pos]['name'] +
': ' + vararg_type] + python_args[vararg_pos + 2:])
Expand Down Expand Up @@ -419,6 +415,9 @@ def gen_nn_functional(out):
}
write(out, 'torch/nn/functional.pyi', stubs, env)

stubs = CodeTemplate.from_file(os.path.join('torch', '_C', '_nn.pyi.in'))
write(out, 'torch/_C/_nn.pyi', stubs, env)

def gen_nn_pyi(out):
gen_nn_functional(out)
gen_nn_modules(out)
Expand Down Expand Up @@ -485,7 +484,9 @@ def gen_pyi(declarations_path, out):
'def full(size: _size, fill_value: Number, *,'
' names: List[Union[str, None]], {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...']
'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'],
})
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
unsorted_function_hints[binop].append(
Expand Down Expand Up @@ -622,11 +623,14 @@ def gen_pyi(declarations_path, out):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_class_hints = []
for c in ('DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage'):
legacy_class_hints.append('class {}(Storage): ...'.format(c))
legacy_storage_base_hints = []
for c in ('Double', 'Float', 'Long', 'Int',
'Short', 'Char', 'Byte', 'Bool',
'Half', 'BFloat16', 'ComplexDouble',
'ComplexFloat', 'QUInt8', 'QInt8', 'QInt32'):
legacy_storage_base_hints.append('class {}StorageBase(object): ...'.format(c))

legacy_class_hints = []
for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor'):
legacy_class_hints.append('class {}(Tensor): ...'.format(c))
Expand All @@ -650,11 +654,15 @@ def gen_pyi(declarations_path, out):
'function_hints': function_hints,
'tensor_method_hints': tensor_method_hints,
'legacy_class_hints': legacy_class_hints,
'legacy_storage_base_hints': legacy_storage_base_hints,
'dtype_class_hints': dtype_class_hints,
}
TORCH_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '__init__.pyi.in'))
TORCH_C_TYPE_STUBS = CodeTemplate.from_file(os.path.join('torch', '_C', '__init__.pyi.in'))
TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS = \
CodeTemplate.from_file(os.path.join('torch', '_C', '_VariableFunctions.pyi.in'))

write(out, 'torch/__init__.pyi', TORCH_TYPE_STUBS, env)
write(out, 'torch/_C/__init__.pyi', TORCH_C_TYPE_STUBS, env)
write(out, 'torch/_C/_VariableFunctions.pyi', TORCH_C_VARIABLE_FUNCTIONS_TYPE_STUBS, env)
gen_nn_pyi(out)


Expand Down
9 changes: 6 additions & 3 deletions torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,23 +268,26 @@ set(ModulesStubOut
${TORCH_SRC_DIR}/nn/modules/upsampling.pyi
)
add_custom_target(torch_python_stubs DEPENDS
"${TORCH_SRC_DIR}/__init__.pyi"
"${TORCH_SRC_DIR}/_C/__init__.pyi"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
"${TORCH_SRC_DIR}/nn/functional.pyi"
${ModuleStubOut}
)
# For Declarations.yaml dependency
add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
add_custom_command(
OUTPUT
"${TORCH_SRC_DIR}/__init__.pyi"
"${TORCH_SRC_DIR}/_C/__init__.pyi"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
"${TORCH_SRC_DIR}/nn/functional.pyi"
${ModuleStubOut}
COMMAND
"${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
DEPENDS
"${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
"${TORCH_SRC_DIR}/__init__.pyi.in"
"${TORCH_SRC_DIR}/_C/__init__.pyi.in"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
"${TORCH_SRC_DIR}/nn/functional.pyi.in"
${ModuleStubIn}
"${TOOLS_PATH}/pyi/gen_pyi.py"
Expand Down
14 changes: 14 additions & 0 deletions torch/_C/_VariableFunctions.pyi.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ${generated_comment}

from torch import Tensor, Generator, strided, memory_format, contiguous_format
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
from torch._six import inf

from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout

import builtins

# REDUNDANT!
${namedtuple_defs}

${function_hints}
86 changes: 0 additions & 86 deletions torch/_C/__init__.pyi

This file was deleted.

Loading

0 comments on commit 6edf340

Please sign in to comment.