Skip to content

Commit

Permalink
add unpack_outputs to inlineCallTo
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#15382

Differential Revision: D13518844

Pulled By: zdevito

fbshipit-source-id: 981936988080af80629b70bf5f6dfa52ceb09c2f
  • Loading branch information
zdevito authored and facebook-github-bot committed Dec 19, 2018
1 parent 07d20b1 commit 0b21953
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
17 changes: 2 additions & 15 deletions torch/csrc/jit/autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,6 @@ bool isDifferentiable(Graph & g) {
static_cast<bool(*)(Node*)>(isDifferentiable));
}

// TODO: Remove this after #15355.
namespace {
// Inlines `callee` into `g` via inlineCallTo and, when the callee produces
// exactly one output whose type is a tuple, immediately unpacks that tuple so
// the caller receives the individual element values rather than the tuple.
// Returns the (possibly unpacked) output values of the inlined call.
std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
  std::vector<Value*> outputs = inlineCallTo(g, callee, inputs);
  const auto& callee_outputs = callee.outputs();
  const bool single_tuple_output = callee_outputs.size() == 1 &&
      callee_outputs.at(0)->type()->kind() == TupleType::Kind;
  if (single_tuple_output) {
    // createTupleUnpack inserts a prim::TupleUnpack node and yields its outputs.
    auto unpacked = createTupleUnpack(outputs.at(0));
    outputs.assign(unpacked.begin(), unpacked.end());
  }
  return outputs;
}
} // anonymous namespace


// NB: Write gradient using torchscript
// For example, node aten::mul() should be defined as follows
// def forward(x, y):
Expand Down Expand Up @@ -200,7 +187,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
{
WithInsertPoint guard(node->next());
auto fw_graph = compiled_graphs->forward;
new_outputs = inlineUnpackedCallTo(*graph, *fw_graph, node->inputs());
new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
for (size_t i = 0; i < node->outputs().size(); ++i) {
new_outputs.at(i)->setType(node->outputs()[i]->type());
new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
Expand All @@ -213,7 +200,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
auto it = grad_vec.begin();
grad_vec.insert(it, new_outputs.back());
ArrayRef<Value*> grad(grad_vec);
auto grad_inputs = inlineUnpackedCallTo(*graph, *bw_graph, grad);
auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
return grad_inputs;
};

Expand Down
17 changes: 16 additions & 1 deletion torch/csrc/jit/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {
return g.insertNode(g.createTupleUnpack(v))->outputs();
}

std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs) {
std::unordered_map<Value*, Value*> value_map;
auto value_map_func = [&](Value* v) { return value_map.at(v); };
JIT_ASSERT(callee.inputs().size() == inputs.size());
Expand All @@ -1514,6 +1514,21 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
for (auto* output : callee.outputs()) {
outputs.push_back(value_map_func(output));
}

if (unpack_outputs && outputs.size() == 1 &&
callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
auto tup = outputs[0];
outputs.clear();
for(Value* v : createTupleUnpack(tup)) {
outputs.emplace_back(v);
}
// if this was a peephole tuple unpack we can just get rid of
// the tuple construct here and prevent needing DCE
if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
tup->node()->destroy();
}
}

return outputs;
}

Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/jit/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,9 @@ inline Node* Graph::createPythonOp(
TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);

TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
// unpack_outputs - if true, and the callee returns a single tuple value, then insert a tuple unpack node
// and return the resulting values
TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs=false);


}} // namespace torch::jit
7 changes: 1 addition & 6 deletions torch/csrc/jit/passes/to_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_
}

// Inlines `callee` into `g` and unpacks a single tuple output into its
// element values. Delegates to inlineCallTo's built-in tuple unpacking
// (unpack_outputs=true) instead of duplicating that logic here.
//
// Fix: the previous body contained two implementations back to back — a
// hand-rolled unpack ending in `return outputs;` followed by an unreachable
// second `return inlineCallTo(...)`. Only the delegating form is kept.
std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
}

// replace aten operator node with BatchTensor operator graph
Expand Down

0 comments on commit 0b21953

Please sign in to comment.