forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RefineTypes JIT pass for Tuple (pytorch#76919)
Consider the following JIT graph, where the type of `%a` and `%b` are out of sync with tuple `%c`. Before: ``` graph(%a : Float(123), %b : Float(4, 5, 6)): c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b) return (%c) ``` After: ``` graph(%a : Float(123), %b : Float(4, 5, 6)): c : (Float(123), Float(4, 5, 6)) = prim::TupleConstruct(%a, %b) return (%c) ``` This PR adds a pass `RefineTypes(...)` to update all such instances with the correct type. This is also available via Python by using `torch._C._jit_pass_refine_types(...)`. A unit test has been added for unnamed tuples, but no test exists for `NamedTuple` (though it was tested manually) since it isn't supported by the parser: ``` RuntimeError: unknown type specifier: graph(%a : Float(123), %b : Float(4, 5, 6)): %c : NamedTuple(Tensor : Tuple, Tensor : Tuple) = prim::TupleConstruct(%a, %b) ~~~~~~~~~~ <--- HERE return (%c) ``` cc: @ke1337 @antoniojkim @wconstab @eellison Pull Request resolved: pytorch#76919 Approved by: https://github.com/eellison
- Loading branch information
1 parent
2881e0e
commit f6eb811
Showing
5 changed files
with
74 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include <torch/csrc/jit/passes/refine_tuple_types.h> | ||
#include <torch/csrc/jit/runtime/graph_iterator.h> | ||
|
||
#include <ATen/core/type_factory.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
namespace { | ||
static void VisitTupleNode(Node* node) { | ||
TORCH_CHECK( | ||
node->outputs().size() == 1, "Tuple must have exactly one output!"); | ||
|
||
Value* output = node->outputs()[0]; | ||
auto tuple_type = output->type()->expectRef<TupleType>(); | ||
|
||
TORCH_CHECK( | ||
tuple_type.containedTypes().size() == node->inputs().size(), | ||
"Number of contained types does not match number of inputs!"); | ||
|
||
// Extract updated types from input values. | ||
std::vector<c10::TypePtr> types; | ||
for (const Value* input : node->inputs()) { | ||
types.push_back(input->type()); | ||
} | ||
|
||
// Construct new tuple type based on input types. | ||
output->setType(tuple_type.withContained(types)); | ||
} | ||
} // anonymous namespace | ||
|
||
void RefineTupleTypes(std::shared_ptr<Graph>& graph) { | ||
DepthFirstGraphNodeIterator it(graph); | ||
for (auto* node = it.next(); node != nullptr; node = it.next()) { | ||
if (node->kind() == prim::TupleConstruct) { | ||
VisitTupleNode(node); | ||
} | ||
} | ||
} | ||
|
||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#pragma once | ||
|
||
#include <torch/csrc/jit/ir/ir.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
// updates the types of tuples according to the type of their current inputs. | ||
TORCH_API void RefineTupleTypes(std::shared_ptr<Graph>& graph); | ||
|
||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters