Move all torch.nn.modules type annotations inline (pytorch#38211)
Summary:
Pull Request resolved: pytorch#38211

Just because the annotations are inline doesn't mean the files type
check; most of the newly annotated files have type errors, and I
added exclusions for them in mypy.ini.  The payoff of moving
all of these modules inline is that I can delete the relevant code
generation logic for the pyi files (which was adding ignore
annotations that weren't actually relevant anymore).

For the most part the translation was completely mechanical, but there
were two hairy issues.  First, I needed to work around a Python 3.6 and
earlier bug where Generic has a nontrivial metaclass; this fix is in
torch/jit/__init__.py.  Second, in module.py, we need to apply the same
fix for avoiding contravariance checks that the pyi file used to have;
this is done by declaring forward as a variable (rather than a
function), which appears to be sufficient to get mypy to not
contravariantly check input arguments.
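
A minimal sketch of that trick, not the actual torch/nn/modules/module.py
source (the _forward_unimplemented helper name is an assumption):

from typing import Any, Callable

def _forward_unimplemented(self, *input: Any) -> None:
    raise NotImplementedError

class Module:
    # Declaring forward as a Callable-typed class variable instead of
    # `def forward(self, *input)` keeps mypy from contravariantly checking
    # the argument types of subclass overrides.
    forward: Callable[..., Any] = _forward_unimplemented

class Linear(Module):
    # Narrower signature than Module's; against a plain base-class method
    # this would trip mypy's Liskov check, against the variable it does not.
    def forward(self, input: float) -> float:
        return 2.0 * input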

Because we aren't actually typechecking these modules in most
cases, it is inevitable that some of these type annotations are wrong.
I slavishly copied the old annotations from the pyi files unless there
was an obvious correction I could make.  These annotations will probably
need fixing up later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497397

Pulled By: ezyang

fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
ezyang authored and facebook-github-bot committed Jun 11, 2020
1 parent e22dd56 commit eace053
Showing 58 changed files with 1,048 additions and 2,141 deletions.
25 changes: 20 additions & 5 deletions aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
@@ -42,14 +42,19 @@ void compute_fused_params(
template <bool ReluFused>
Tensor q_batch_norm2d_impl(
Tensor qx,
Tensor weight,
Tensor bias,
c10::optional<Tensor> mb_weight,
c10::optional<Tensor> mb_bias,
Tensor mean,
Tensor var,
double eps,
double output_scale,
int64_t output_zero_point) {

TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");
const auto& weight = *mb_weight;
const auto& bias = *mb_bias;

if (qx.numel() == 0) {
auto out = qx.clone();
return out;
@@ -131,14 +136,20 @@ Tensor q_batch_norm2d_impl(
template <bool ReluFused>
Tensor q_batch_norm3d_impl(
Tensor qx,
Tensor weight,
Tensor bias,
c10::optional<Tensor> mb_weight,
c10::optional<Tensor> mb_bias,
Tensor mean,
Tensor var,
double eps,
double output_scale,
int64_t output_zero_point) {

TORCH_CHECK(mb_weight.has_value(), "Weight must be provided");
TORCH_CHECK(mb_bias.has_value(), "Bias must be provided");

const auto& weight = *mb_weight;
const auto& bias = *mb_bias;

if (qx.numel() == 0) {
auto out = qx.clone();
return out;
@@ -231,8 +242,12 @@ Tensor quantized_batch_norm(
double output_scale,
int64_t output_zero_point) {
Tensor qy;
// TODO: this should arguably support 3d as well
qy = q_batch_norm2d_impl<false>(
qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
qx,
weight.defined() ? c10::make_optional(weight) : c10::nullopt,
bias.defined() ? c10::make_optional(bias) : c10::nullopt,
mean, var, eps, output_scale, output_zero_point);
return qy;
}

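A hedged Python-level usage sketch for the change above: weight and bias
become c10::optional in the kernel signature, but the TORCH_CHECKs still
require real tensors, so a caller passes them explicitly (op name and
argument order follow the quantized library schema later in this diff;
the numeric values are illustrative):

import torch

C = 3
qx = torch.quantize_per_tensor(
    torch.randn(1, C, 4, 4), scale=0.1, zero_point=0, dtype=torch.quint8)
weight, bias = torch.ones(C), torch.zeros(C)
mean, var = torch.zeros(C), torch.ones(C)

# trailing arguments: eps, output_scale, output_zero_point
qy = torch.ops.quantized.batch_norm2d(qx, weight, bias, mean, var, 1e-5, 0.1, 0)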
28 changes: 19 additions & 9 deletions aten/src/ATen/native/quantized/cpu/qnormalization.cpp
@@ -123,33 +123,43 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl("layer_norm", [](
Tensor input,
std::vector<int64_t> normalized_shape, // because IntArrayRef doesn't work
Tensor weight /* optional */,
Tensor bias /* optional */,
c10::optional<Tensor> weight,
c10::optional<Tensor> bias,
double eps,
double output_scale,
int64_t output_zero_point) {
return quantized_layer_norm_impl(input, normalized_shape, weight, bias, eps, output_scale, output_zero_point);
return quantized_layer_norm_impl(
input, normalized_shape,
weight.has_value() ? *weight : Tensor(),
bias.has_value() ? *bias : Tensor(),
eps, output_scale, output_zero_point);
});
m.impl("group_norm", [](
Tensor qx,
int64_t num_groups,
Tensor weight,
Tensor bias,
c10::optional<Tensor> weight,
c10::optional<Tensor> bias,
double eps,
double output_scale,
int64_t output_zero_point) {
return quantized_group_norm_impl(
qx, num_groups, weight, bias, eps, output_scale, output_zero_point);
qx, num_groups,
weight.has_value() ? *weight : Tensor(),
bias.has_value() ? *bias : Tensor(),
eps, output_scale, output_zero_point);
});
m.impl("instance_norm", [](
Tensor qx,
Tensor weight,
Tensor bias,
c10::optional<Tensor> weight,
c10::optional<Tensor> bias,
double eps,
double output_scale,
int64_t output_zero_point) {
return quantized_instance_norm_impl(
qx, weight, bias, eps, output_scale, output_zero_point);
qx,
weight.has_value() ? *weight : Tensor(),
bias.has_value() ? *bias : Tensor(),
eps, output_scale, output_zero_point);
});
}

14 changes: 7 additions & 7 deletions aten/src/ATen/native/quantized/library.cpp
@@ -25,10 +25,10 @@ TORCH_LIBRARY(quantized, m) {
m.def("add_scalar_relu(Tensor qa, Scalar b) -> Tensor qc");
m.def("add_scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out");
m.def("add_scalar_relu_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out");
m.def("batch_norm2d(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm2d_relu(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm3d(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm3d_relu(Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm2d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm2d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm3d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("batch_norm3d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("clamp(Tensor qx, Scalar? min, Scalar? max) -> Tensor qy");
m.def("threshold(Tensor qx, Scalar threshold, Scalar value) -> Tensor qy");
m.def("cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor");
@@ -64,9 +64,9 @@ TORCH_LIBRARY(quantized, m) {
m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int");
m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
m.def("group_norm(Tensor input, int num_groups, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("instance_norm(Tensor input, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("layer_norm(Tensor input, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("instance_norm(Tensor input, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def("layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");
m.def(
"linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y");
m.def(
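Since the schemas above now declare weight and bias as Tensor?, Python
callers may pass None; the qnormalization.cpp change earlier converts None
into an undefined Tensor before dispatching. A hedged usage sketch, assuming
the underlying layer_norm kernel accepts undefined weight/bias as the old
/* optional */ comments suggest:

import torch

x = torch.randn(2, 5, 10)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

# arguments: normalized_shape, weight=None, bias=None, eps, output_scale, output_zero_point
qy = torch.ops.quantized.layer_norm(qx, [10], None, None, 1e-5, 0.1, 0)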
45 changes: 45 additions & 0 deletions mypy.ini
@@ -146,6 +146,51 @@ ignore_errors = True
[mypy-torch._tensor_str]
ignore_errors = True

[mypy-torch.nn.modules.activation]
ignore_errors = True

[mypy-torch.nn.modules.batchnorm]
ignore_errors = True

[mypy-torch.nn.modules.container]
ignore_errors = True

[mypy-torch.nn.modules.conv]
ignore_errors = True

[mypy-torch.nn.modules.fold]
ignore_errors = True

[mypy-torch.nn.modules.instancenorm]
ignore_errors = True

[mypy-torch.nn.modules.linear]
ignore_errors = True

[mypy-torch.nn.modules.loss]
ignore_errors = True

[mypy-torch.nn.modules.module]
ignore_errors = True

[mypy-torch.nn.modules.normalization]
ignore_errors = True

[mypy-torch.nn.modules.padding]
ignore_errors = True

[mypy-torch.nn.modules.pooling]
ignore_errors = True

[mypy-torch.nn.modules.rnn]
ignore_errors = True

[mypy-torch.nn.modules.sparse]
ignore_errors = True

[mypy-torch.nn.modules.upsampling]
ignore_errors = True

[mypy-torch.nn.parallel._functions]
ignore_errors = True

20 changes: 2 additions & 18 deletions test/type_hint_tests/module_list.py
@@ -1,25 +1,9 @@
from typing import Iterable

import torch

# ModuleList with elements of a specific type
# ModuleList with elements of type Module
class FooModule(torch.nn.Module):
def ten(self) -> int:
return 10

class FooCollector(torch.nn.Module):
def __init__(self, ml: Iterable[FooModule]) -> None:
super(FooCollector, self).__init__()
self.ml: torch.nn.ModuleList[FooModule] = torch.nn.ModuleList(ml)

def foo_sum(self) -> int:
return sum(foo.ten() for foo in self.ml)

collector = FooCollector([FooModule(), FooModule()])
twenty = collector.foo_sum()
twenty == 20
pass

# ModuleList with elements of type Module
class BarModule(torch.nn.Module):
pass

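For reference, a hedged sketch of how the removed element-typed collector can
still be written once ModuleList is annotated inline without a type parameter
(the cast is an assumption about the resulting annotations, not something this
diff adds):

from typing import Iterable, cast

import torch

class FooModule(torch.nn.Module):
    def ten(self) -> int:
        return 10

class FooCollector(torch.nn.Module):
    def __init__(self, ml: Iterable[FooModule]) -> None:
        super().__init__()
        self.ml = torch.nn.ModuleList(ml)

    def foo_sum(self) -> int:
        # Iteration yields plain Modules, so narrow the type explicitly.
        return sum(cast(FooModule, foo).ten() for foo in self.ml)

assert FooCollector([FooModule(), FooModule()]).foo_sum() == 20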
22 changes: 0 additions & 22 deletions tools/pyi/gen_pyi.py
@@ -1,7 +1,6 @@
from __future__ import print_function
import os
import collections
import glob
import yaml
import re
import argparse
@@ -342,26 +341,6 @@ def generate_type_hints(fname, decls, namedtuples, is_tensor=False):

return type_hints

def gen_nn_modules(out):
def replace_forward(m):
# We instruct mypy to not emit errors for the `forward` and `__call__` declarations since mypy
# would otherwise correctly point out that Module's descendants' `forward` declarations
# conflict with `Module`'s. Specifically, `Module` defines `forward(self, *args)` while the
# descendants define more specific forms, such as `forward(self, input: Tensor)`, which
# violates Liskov substitutability. The 'mypy' team recommended this solution for now.
forward_def = m.group(0) + " # type: ignore"
call_def = re.sub(r'def forward', 'def __call__', forward_def)
new_def = "{}\n{}".format(forward_def, call_def)
return new_def
pattern = re.compile(r'^\s*def forward\(self.*$', re.MULTILINE)
for fname in glob.glob("torch/nn/modules/*.pyi.in"):
with open(fname, 'r') as f:
src = f.read()
res = pattern.sub(replace_forward, src)
fname_out = fname[:-3]
with open(os.path.join(out, fname_out), 'w') as f:
f.write(res)

def gen_nn_functional(out):
# Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
# through an `_add_docstr` call
@@ -421,7 +400,6 @@ def gen_nn_functional(out):

def gen_nn_pyi(out):
gen_nn_functional(out)
gen_nn_modules(out)

def gen_pyi(declarations_path, out):
"""gen_pyi()
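To make the deleted post-processing concrete, here is the removed
replace_forward logic replayed on a hypothetical stub snippet; it shows what
the generated .pyi files used to contain (a # type: ignore on each forward
stub plus a duplicated __call__ declaration):

import re

src = (
    "class Linear(Module):\n"
    "    def forward(self, input: Tensor) -> Tensor: ...\n"
)

def replace_forward(m):
    # Same transformation as the removed helper: annotate and duplicate.
    forward_def = m.group(0) + " # type: ignore"
    call_def = re.sub(r'def forward', 'def __call__', forward_def)
    return "{}\n{}".format(forward_def, call_def)

pattern = re.compile(r'^\s*def forward\(self.*$', re.MULTILINE)
print(pattern.sub(replace_forward, src))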
52 changes: 0 additions & 52 deletions torch/CMakeLists.txt
@@ -222,60 +222,10 @@ endif()
# upsampling
# )
# list(TRANSFORM Modules PREPEND "${TORCH_SRC_DIR}/nn/modules/")
# set(ModuleStubIn ${Modules})
# set(ModuleStubOut ${Modules})
# list(TRANSFORM ModuleStubIn APPEND ".pyi.in")
# list(TRANSFORM ModuleStubOut APPEND ".pyi")
set(ModulesStubIn
${TORCH_SRC_DIR}/nn/modules/__init__.pyi.in
${TORCH_SRC_DIR}/nn/modules/activation.pyi.in
${TORCH_SRC_DIR}/nn/modules/adaptive.pyi.in
${TORCH_SRC_DIR}/nn/modules/batchnorm.pyi.in
${TORCH_SRC_DIR}/nn/modules/container.pyi.in
${TORCH_SRC_DIR}/nn/modules/conv.pyi.in
${TORCH_SRC_DIR}/nn/modules/distance.pyi.in
${TORCH_SRC_DIR}/nn/modules/dropout.pyi.in
${TORCH_SRC_DIR}/nn/modules/fold.pyi.in
${TORCH_SRC_DIR}/nn/modules/flatten.pyi.in
${TORCH_SRC_DIR}/nn/modules/instancenorm.pyi.in
${TORCH_SRC_DIR}/nn/modules/linear.pyi.in
${TORCH_SRC_DIR}/nn/modules/loss.pyi.in
${TORCH_SRC_DIR}/nn/modules/module.pyi.in
${TORCH_SRC_DIR}/nn/modules/normalization.pyi.in
${TORCH_SRC_DIR}/nn/modules/padding.pyi.in
${TORCH_SRC_DIR}/nn/modules/pixelshuffle.pyi.in
${TORCH_SRC_DIR}/nn/modules/pooling.pyi.in
${TORCH_SRC_DIR}/nn/modules/rnn.pyi.in
${TORCH_SRC_DIR}/nn/modules/sparse.pyi.in
${TORCH_SRC_DIR}/nn/modules/upsampling.pyi.in
)
set(ModulesStubOut
${TORCH_SRC_DIR}/nn/modules/__init__.pyi
${TORCH_SRC_DIR}/nn/modules/activation.pyi
${TORCH_SRC_DIR}/nn/modules/adaptive.pyi
${TORCH_SRC_DIR}/nn/modules/batchnorm.pyi
${TORCH_SRC_DIR}/nn/modules/container.pyi
${TORCH_SRC_DIR}/nn/modules/conv.pyi
${TORCH_SRC_DIR}/nn/modules/distance.pyi
${TORCH_SRC_DIR}/nn/modules/dropout.pyi
${TORCH_SRC_DIR}/nn/modules/fold.pyi
${TORCH_SRC_DIR}/nn/modules/instancenorm.pyi
${TORCH_SRC_DIR}/nn/modules/linear.pyi
${TORCH_SRC_DIR}/nn/modules/loss.pyi
${TORCH_SRC_DIR}/nn/modules/module.pyi
${TORCH_SRC_DIR}/nn/modules/normalization.pyi
${TORCH_SRC_DIR}/nn/modules/padding.pyi
${TORCH_SRC_DIR}/nn/modules/pixelshuffle.pyi
${TORCH_SRC_DIR}/nn/modules/pooling.pyi
${TORCH_SRC_DIR}/nn/modules/rnn.pyi
${TORCH_SRC_DIR}/nn/modules/sparse.pyi
${TORCH_SRC_DIR}/nn/modules/upsampling.pyi
)
add_custom_target(torch_python_stubs DEPENDS
"${TORCH_SRC_DIR}/_C/__init__.pyi"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
"${TORCH_SRC_DIR}/nn/functional.pyi"
${ModuleStubOut}
)
# For Declarations.yaml dependency
add_dependencies(torch_python_stubs ATEN_CPU_FILES_GEN_TARGET)
@@ -284,7 +234,6 @@ add_custom_command(
"${TORCH_SRC_DIR}/_C/__init__.pyi"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi"
"${TORCH_SRC_DIR}/nn/functional.pyi"
${ModuleStubOut}
COMMAND
"${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
@@ -293,7 +242,6 @@ add_custom_command(
"${TORCH_SRC_DIR}/_C/__init__.pyi.in"
"${TORCH_SRC_DIR}/_C/_VariableFunctions.pyi.in"
"${TORCH_SRC_DIR}/nn/functional.pyi.in"
${ModuleStubIn}
"${TOOLS_PATH}/pyi/gen_pyi.py"
WORKING_DIRECTORY
"${TORCH_ROOT}"
15 changes: 11 additions & 4 deletions torch/jit/__init__.py
@@ -29,7 +29,6 @@
import warnings
import weakref


# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, _overload, _overload_method
from torch._jit_internal import ignore, export, unused
@@ -1394,6 +1393,10 @@ def interface(obj):
if not _is_new_style_class(obj):
raise RuntimeError("TorchScript interfaces must inherit from 'object'")

# Expected MRO is:
# User module
# torch.nn.modules.module.Module
# object
is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3

if not is_module_interface and len(obj.mro()) > 2:
@@ -1555,7 +1558,7 @@ def __getitem__(self, k):
# parameters are initialized _before_ the script compiler resolve references to
# `self.param` or `self.module`.
class ScriptMeta(type):
def __init__(cls, name, bases, attrs):
def __init__(cls, name, bases, attrs): # noqa: B902
# Aggregate all the ScriptMethods and constants from superclasses
cls._methods = {}
cls._constants_set = set(getattr(cls, '__constants__', ()))
@@ -1641,8 +1644,12 @@ def __setattr__(self, attr, value):
# This ensures that if we use the attr again in `__init__`, it
# will look like the actual value, not an instance of Attribute.
if isinstance(value, Attribute):
if not hasattr(self, "__annotations__"):
self.__annotations__ = {}
# NB: Ensure that we set __annotations__ on the specific
# class in question, and not on a superclass (which would
# be wrong wrong wrong!).
# See also https://github.com/pytorch/pytorch/issues/39463
if "__annotations__" not in self.__class__.__dict__:
self.__class__.__annotations__ = {}
self.__annotations__[attr] = value.type
value = value.value
return super(ScriptModule, self).__setattr__(attr, value)
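A standalone demo (plain Python, not PyTorch code) of the pitfall the
__annotations__ check above avoids: hasattr() is satisfied by a superclass's
__annotations__, so the old check could end up writing attribute types into
the parent class's dict:

class Base:
    x: int  # Base gets its own __annotations__ dict

class Child(Base):
    pass

c = Child()
print(hasattr(c, "__annotations__"))        # True  -- inherited from Base
print("__annotations__" in Child.__dict__)  # False -- Child has none of its own

# With the old check, c.__annotations__[...] = ... would therefore mutate
# Base.__annotations__; keying on self.__class__.__dict__ avoids that.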