Add RefineTypes JIT pass for Tuple (pytorch#76919)
Consider the following JIT graph, where the type of `%a` and `%b` are out of sync with tuple `%c`.
Before:
```
graph(%a : Float(123), %b : Float(4, 5, 6)):
    %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
    return (%c)
```
After:
```
graph(%a : Float(123), %b : Float(4, 5, 6)):
    %c : (Float(123), Float(4, 5, 6)) = prim::TupleConstruct(%a, %b)
    return (%c)
```
This PR adds a pass `RefineTupleTypes(...)` that updates all such instances with the correct types. It is also exposed to Python as `torch._C._jit_pass_refine_tuple_types(...)`.
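
For illustration, here is a minimal sketch of how the pass can be invoked from Python (it mirrors the unit test added below; `torch._C.parse_ir` is assumed here to build the graph from an IR string):
```
import torch

# Build a graph whose tuple type is out of sync with its input types.
graph_str = """
graph(%a : Float(123), %b : Float(4, 5, 6)):
  %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
  return (%c)
"""
graph = torch._C.parse_ir(graph_str)

# Run the new pass; the tuple output type is refined in place to
# (Float(123), Float(4, 5, 6)).
torch._C._jit_pass_refine_tuple_types(graph)
print(graph)
```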

A unit test has been added for unnamed tuples. `NamedTuple` was tested manually, but no unit test exists for it because the IR parser does not support it (a sketch of such a manual check follows the error output below):
```
RuntimeError:
unknown type specifier:

        graph(%a : Float(123), %b : Float(4, 5, 6)):
          %c : NamedTuple(Tensor : Tuple, Tensor : Tuple) = prim::TupleConstruct(%a, %b)
               ~~~~~~~~~~ <--- HERE
          return (%c)
```
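
Since the parser cannot express `NamedTuple` types, a manual check has to go through TorchScript instead. A hedged sketch of how such a check might look (not necessarily the exact one used for this PR):
```
from typing import NamedTuple
import torch

class Pair(NamedTuple):
    first: torch.Tensor
    second: torch.Tensor

@torch.jit.script
def make_pair(a: torch.Tensor, b: torch.Tensor) -> Pair:
    return Pair(a, b)

# Running the pass on the scripted graph updates the contained types of the
# prim::TupleConstruct output to match the current types of its inputs
# (e.g. after shape/type information has been attached to them).
torch._C._jit_pass_refine_tuple_types(make_pair.graph)
print(make_pair.graph)
```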

cc: @ke1337 @antoniojkim @wconstab @eellison
Pull Request resolved: pytorch#76919
Approved by: https://github.com/eellison
henrytwo authored and pytorchmergebot committed May 12, 2022
1 parent 2881e0e commit f6eb811
Showing 5 changed files with 74 additions and 0 deletions.
15 changes: 15 additions & 0 deletions test/test_jit.py
@@ -11252,6 +11252,21 @@ def func(a):
self.run_pass("erase_number_types", graph)
FileCheck().check_not("int = prim::Constant").run(str(graph))

def test_refine_tuple_types(self):
# TupleConstruct output type is not correct here.
graph_str = """
graph(%a : Float(123), %b : Float(4, 5, 6)):
%c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
return (%c)
"""
graph = parse_ir(graph_str)
torch._C._jit_pass_refine_tuple_types(graph)

# After the pass, the output type should've been updated.
self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))

# TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.

def test_remove_dropout(self):
weight_0_shape = (20, 5)
weight_1_shape = (20, 20)
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -257,6 +257,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/passes/peephole.cpp",
"torch/csrc/jit/passes/peephole_non_tensor.cpp",
"torch/csrc/jit/passes/create_functional_graphs.cpp",
"torch/csrc/jit/passes/refine_tuple_types.cpp",
"torch/csrc/jit/passes/remove_mutation.cpp",
"torch/csrc/jit/passes/prepack_folding.cpp",
"torch/csrc/jit/passes/fold_conv_bn.cpp",
42 changes: 42 additions & 0 deletions torch/csrc/jit/passes/refine_tuple_types.cpp
@@ -0,0 +1,42 @@
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <ATen/core/type_factory.h>

namespace torch {
namespace jit {

namespace {
static void VisitTupleNode(Node* node) {
TORCH_CHECK(
node->outputs().size() == 1, "Tuple must have exactly one output!");

Value* output = node->outputs()[0];
auto tuple_type = output->type()->expectRef<TupleType>();

TORCH_CHECK(
tuple_type.containedTypes().size() == node->inputs().size(),
"Number of contained types does not match number of inputs!");

// Extract updated types from input values.
std::vector<c10::TypePtr> types;
for (const Value* input : node->inputs()) {
types.push_back(input->type());
}

// Construct new tuple type based on input types.
output->setType(tuple_type.withContained(types));
}
} // anonymous namespace

void RefineTupleTypes(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
for (auto* node = it.next(); node != nullptr; node = it.next()) {
if (node->kind() == prim::TupleConstruct) {
VisitTupleNode(node);
}
}
}

} // namespace jit
} // namespace torch
12 changes: 12 additions & 0 deletions torch/csrc/jit/passes/refine_tuple_types.h
@@ -0,0 +1,12 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// updates the types of tuples according to the type of their current inputs.
TORCH_API void RefineTupleTypes(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -58,6 +58,7 @@
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
@@ -899,6 +900,9 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_remove_dropout",
[](script::Module& module) { return removeDropout(module); })
.def(
"_jit_pass_refine_tuple_types",
[](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
.def(
"_jit_pass_transform_conv1d_to_conv2d",
[](std::shared_ptr<Graph>& graph) {
