Skip to content

Commit

Permalink
Reenable all forward-pass fusions that worked before the AD fix (pyto…
Browse files Browse the repository at this point in the history
…rch#14558)

Summary:
Dealing with so many `aten::size` calls (in particular calls on elements computed inside fusion groups) requires us to do some extra graph processing in the fuser (to compute the sizes by explicit broadcasts, instead of writing the intermediate tensors only to check their size). This restores the forward expects of LSTM and MiLSTM to a single big kernel. Unfortunately the backward is much harder, because as long as we can't prove that the reductions are unnecessary (or if we can't distribute them over the op), we will not be able to fuse them.
Pull Request resolved: pytorch#14558

Differential Revision: D13321748

Pulled By: zou3519

fbshipit-source-id: c04fc2f70d106d2bfb56206b5aec517a93b79d1f
  • Loading branch information
apaszke authored and facebook-github-bot committed Dec 4, 2018
1 parent c3bfa0e commit d76fd43
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 127 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ namespace c10 {
_(namespaces, namespaces) \
_(prim, Assign) \
_(prim, BroadcastingChunk) \
_(prim, BroadcastSizes) \
_(prim, Constant) \
_(prim, ChunkSizes) \
_(prim, None) \
_(prim, Drop) \
_(prim, Eval) \
Expand Down
113 changes: 61 additions & 52 deletions test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
Original file line number Diff line number Diff line change
Expand Up @@ -19,59 +19,68 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%8 : Float(*, *) = aten::mm(%5, %7)
%9 : Float(*, *) = aten::t(%4)
%10 : Float(*, *) = aten::mm(%3, %9)
%11 : int = prim::Constant[value=1]()
%12 : int[] = aten::size(%8)
%13 : int[] = aten::size(%10)
%14 : int[] = aten::size(%2)
%15 : Float(*, *), %16 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
%17 : int[] = aten::size(%16)
%18 : int[] = aten::size(%15)
%19 : int[] = aten::size(%1)
%20 : Tensor[] = prim::ListConstruct(%15, %1)
%21 : Tensor[] = aten::broadcast_tensors(%20)
%22 : Tensor, %23 : Tensor = prim::ListUnpack(%21)
%24 : int[] = aten::size(%0)
%hy : Float(*, *), %26 : Float(*, *), %cy : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %23, %22)
%34 : int[] = aten::size(%ingate.1)
%35 : int[] = aten::size(%forgetgate.1)
%36 : int[] = aten::size(%cellgate.1)
%37 : int[] = aten::size(%outgate.1)
%38 : int[] = aten::size(%29)
%39 : int[] = aten::size(%28)
%40 : int[] = aten::size(%26)
return (%hy, %cy, %7, %9, %12, %13, %17, %14, %18, %19, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %35, %24, %34, %36, %38, %39, %26, %37, %40);
%11 : int[] = aten::size(%8)
%12 : int[] = aten::size(%10)
%13 : int[] = aten::size(%2)
%14 : int[] = aten::size(%1)
%15 : Tensor[] = prim::ListConstruct(%1, %2, %8, %10)
%16 : Tensor[] = aten::broadcast_tensors(%15)
%17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
%21 : int[] = prim::BroadcastSizes(%11, %12)
%22 : int[] = prim::BroadcastSizes(%21, %13)
%23 : int[] = aten::size(%0)
%hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
%31 : int[] = aten::size(%25)
%32 : int[] = aten::size(%outgate.1)
%33 : int[] = aten::size(%cellgate.1)
%34 : int[] = aten::size(%forgetgate.1)
%35 : int[] = aten::size(%ingate.1)
%36 : int[] = prim::BroadcastSizes(%34, %23)
%37 : int[] = prim::BroadcastSizes(%35, %33)
return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %23, %35, %33, %36, %37, %25, %32, %31);
}
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
%2 : Float(*, *)) {
%3 : int = prim::Constant[value=1]()
%4 : Float(*, *) = aten::add(%1, %2, %3)
%5 : int = prim::Constant[value=1]()
%6 : Float(*, *) = aten::add(%4, %0, %5)
return (%6, %4);
}
with prim::FusionGroup_1 = graph(%0 : Float(*, *)
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%2 : Tensor) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%3, %7, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%4, %8, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%5, %9, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%6, %10, %17)
%ingate.1 : Float(*, *) = aten::sigmoid(%12)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
%cellgate.1 : Float(*, *) = aten::tanh(%16)
%outgate.1 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%2 : Tensor
%3 : Tensor
%4 : Tensor) {
%5 : Float(*, *), %6 : Float(*, *), %7 : Float(*, *), %8 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
%9 : Float(*, *), %10 : Float(*, *), %11 : Float(*, *), %12 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
%13 : Float(*, *), %14 : Float(*, *), %15 : Float(*, *), %16 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%17 : Float(*, *), %18 : Float(*, *), %19 : Float(*, *), %20 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *) = aten::add(%13, %17, %21)
%23 : int = prim::Constant[value=1]()
%24 : Float(*, *) = aten::add(%14, %18, %23)
%25 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%23, %24, %25)
%27 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %27)
return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
%26 : Float(*, *) = aten::add(%15, %19, %25)
%27 : int = prim::Constant[value=1]()
%28 : Float(*, *) = aten::add(%16, %20, %27)
%29 : int = prim::Constant[value=1]()
%30 : Float(*, *) = aten::add(%22, %9, %29)
%31 : int = prim::Constant[value=1]()
%32 : Float(*, *) = aten::add(%24, %10, %31)
%33 : int = prim::Constant[value=1]()
%34 : Float(*, *) = aten::add(%26, %11, %33)
%35 : int = prim::Constant[value=1]()
%36 : Float(*, *) = aten::add(%28, %12, %35)
%37 : int = prim::Constant[value=1]()
%38 : Float(*, *) = aten::add(%30, %5, %37)
%39 : int = prim::Constant[value=1]()
%40 : Float(*, *) = aten::add(%32, %6, %39)
%41 : int = prim::Constant[value=1]()
%42 : Float(*, *) = aten::add(%34, %7, %41)
%43 : int = prim::Constant[value=1]()
%44 : Float(*, *) = aten::add(%36, %8, %43)
%ingate.1 : Float(*, *) = aten::sigmoid(%38)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%40)
%cellgate.1 : Float(*, *) = aten::tanh(%42)
%outgate.1 : Float(*, *) = aten::sigmoid(%44)
%49 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%50 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%51 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%49, %50, %51)
%53 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %53)
return (%hy, %53, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
144 changes: 82 additions & 62 deletions test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
Original file line number Diff line number Diff line change
Expand Up @@ -25,69 +25,89 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
%Uz.1 : Float(*, *) = aten::mm(%5, %11)
%13 : int[] = aten::size(%4)
%14 : int[] = aten::size(%Wx.1)
%15 : int[] = aten::size(%Uz.1)
%16 : int[] = aten::size(%3)
%17 : int = prim::Constant[value=1]()
%18 : int[] = aten::size(%2)
%19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *), %23 : Float(*, *), %24 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
%25 : int[] = aten::size(%24)
%26 : int[] = aten::size(%23)
%27 : int[] = aten::size(%22)
%28 : int[] = aten::size(%21)
%29 : int[] = aten::size(%20)
%30 : int[] = aten::size(%19)
%31 : int[] = aten::size(%1)
%32 : Tensor[] = prim::ListConstruct(%19, %1)
%33 : Tensor[] = aten::broadcast_tensors(%32)
%34 : Tensor, %35 : Tensor = prim::ListUnpack(%33)
%36 : int[] = aten::size(%0)
%hy : Float(*, *), %38 : Float(*, *), %cy : Float(*, *), %40 : Float(*, *), %41 : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %35, %34)
%15 : Float(*, *) = aten::mul(%4, %Wx.1)
%16 : int[] = aten::size(%15)
%17 : int[] = aten::size(%Uz.1)
%18 : int[] = aten::size(%3)
%19 : int[] = aten::size(%2)
%20 : int[] = aten::size(%1)
%21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %15, %3, %Wx.1)
%22 : Tensor[] = aten::broadcast_tensors(%21)
%23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22)
%29 : int[] = prim::BroadcastSizes(%18, %14)
%30 : int[] = prim::BroadcastSizes(%16, %17)
%31 : int[] = prim::BroadcastSizes(%19, %17)
%32 : int[] = prim::BroadcastSizes(%30, %29)
%33 : int[] = prim::BroadcastSizes(%32, %31)
%34 : int[] = aten::size(%0)
%hy : Float(*, *), %36 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23)
%42 : int[] = aten::size(%36)
%43 : int[] = aten::size(%outgate.1)
%44 : int[] = aten::size(%cellgate.1)
%45 : int[] = aten::size(%forgetgate.1)
%46 : int[] = aten::size(%ingate.1)
%47 : int[] = aten::size(%forgetgate.1)
%48 : int[] = aten::size(%cellgate.1)
%49 : int[] = aten::size(%outgate.1)
%50 : int[] = aten::size(%41)
%51 : int[] = aten::size(%40)
%52 : int[] = aten::size(%38)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %24, %25, %15, %16, %14, %26, %27, %18, %15, %28, %29, %30, %31, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %47, %36, %46, %48, %50, %51, %38, %49, %52);
%47 : int[] = prim::BroadcastSizes(%45, %34)
%48 : int[] = prim::BroadcastSizes(%46, %44)
return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %13, %14, %15, %16, %17, %18, %14, %30, %29, %19, %17, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %34, %46, %44, %47, %48, %36, %43, %42);
}
with prim::FusionGroup_0 = graph(%0 : Float(*)
%1 : Float(*, *)
%2 : Float(*)
%3 : Float(*, *)
%4 : Float(*)) {
%5 : Float(*, *) = aten::mul(%4, %3)
%6 : Float(*, *) = aten::mul(%5, %1)
%7 : Float(*, *) = aten::mul(%2, %3)
%8 : int = prim::Constant[value=1]()
%9 : Float(*, *) = aten::add(%6, %7, %8)
%10 : Float(*, *) = aten::mul(%0, %1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%9, %10, %11)
return (%12, %10, %9, %7, %6, %5);
}
with prim::FusionGroup_1 = graph(%0 : Float(*, *)
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Tensor
%2 : Tensor) {
%3 : Float(*, *), %4 : Float(*, *), %5 : Float(*, *), %6 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%11 : int = prim::Constant[value=1]()
%12 : Float(*, *) = aten::add(%3, %7, %11)
%13 : int = prim::Constant[value=1]()
%14 : Float(*, *) = aten::add(%4, %8, %13)
%15 : int = prim::Constant[value=1]()
%16 : Float(*, *) = aten::add(%5, %9, %15)
%17 : int = prim::Constant[value=1]()
%18 : Float(*, *) = aten::add(%6, %10, %17)
%ingate.1 : Float(*, *) = aten::sigmoid(%12)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%14)
%cellgate.1 : Float(*, *) = aten::tanh(%16)
%outgate.1 : Float(*, *) = aten::sigmoid(%18)
%23 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%24 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%25 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%23, %24, %25)
%27 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %27)
return (%hy, %27, %cy, %24, %23, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
%2 : Tensor
%3 : Tensor
%4 : Tensor
%5 : Tensor
%6 : Tensor) {
%7 : Float(*, *), %8 : Float(*, *), %9 : Float(*, *), %10 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%6)
%11 : Float(*, *), %12 : Float(*, *), %13 : Float(*, *), %14 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%5)
%15 : Float(*, *), %16 : Float(*, *), %17 : Float(*, *), %18 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%4)
%19 : Float(*, *), %20 : Float(*, *), %21 : Float(*, *), %22 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%3)
%23 : Float(*, *), %24 : Float(*, *), %25 : Float(*, *), %26 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%2)
%27 : Float(*, *), %28 : Float(*, *), %29 : Float(*, *), %30 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%1)
%31 : Float(*, *) = aten::mul(%23, %27)
%32 : Float(*, *) = aten::mul(%24, %28)
%33 : Float(*, *) = aten::mul(%25, %29)
%34 : Float(*, *) = aten::mul(%26, %30)
%35 : Float(*, *) = aten::mul(%19, %15)
%36 : Float(*, *) = aten::mul(%20, %16)
%37 : Float(*, *) = aten::mul(%21, %17)
%38 : Float(*, *) = aten::mul(%22, %18)
%39 : Float(*, *) = aten::mul(%11, %15)
%40 : Float(*, *) = aten::mul(%12, %16)
%41 : Float(*, *) = aten::mul(%13, %17)
%42 : Float(*, *) = aten::mul(%14, %18)
%43 : int = prim::Constant[value=1]()
%44 : Float(*, *) = aten::add(%35, %31, %43)
%45 : int = prim::Constant[value=1]()
%46 : Float(*, *) = aten::add(%36, %32, %45)
%47 : int = prim::Constant[value=1]()
%48 : Float(*, *) = aten::add(%37, %33, %47)
%49 : int = prim::Constant[value=1]()
%50 : Float(*, *) = aten::add(%38, %34, %49)
%51 : int = prim::Constant[value=1]()
%52 : Float(*, *) = aten::add(%44, %39, %51)
%53 : int = prim::Constant[value=1]()
%54 : Float(*, *) = aten::add(%46, %40, %53)
%55 : int = prim::Constant[value=1]()
%56 : Float(*, *) = aten::add(%48, %41, %55)
%57 : int = prim::Constant[value=1]()
%58 : Float(*, *) = aten::add(%50, %42, %57)
%59 : int = prim::Constant[value=1]()
%60 : Float(*, *) = aten::add(%52, %7, %59)
%61 : int = prim::Constant[value=1]()
%62 : Float(*, *) = aten::add(%54, %8, %61)
%63 : int = prim::Constant[value=1]()
%64 : Float(*, *) = aten::add(%56, %9, %63)
%65 : int = prim::Constant[value=1]()
%66 : Float(*, *) = aten::add(%58, %10, %65)
%ingate.1 : Float(*, *) = aten::sigmoid(%60)
%forgetgate.1 : Float(*, *) = aten::sigmoid(%62)
%cellgate.1 : Float(*, *) = aten::tanh(%64)
%outgate.1 : Float(*, *) = aten::sigmoid(%66)
%71 : Float(*, *) = aten::mul(%forgetgate.1, %0)
%72 : Float(*, *) = aten::mul(%ingate.1, %cellgate.1)
%73 : int = prim::Constant[value=1]()
%cy : Float(*, *) = aten::add(%71, %72, %73)
%75 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate.1, %75)
return (%hy, %75, %cy, %outgate.1, %cellgate.1, %forgetgate.1, %ingate.1);
}
10 changes: 4 additions & 6 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4308,9 +4308,8 @@ def test_lstm_fusion_cuda(self):
inputs = get_lstm_inputs('cuda', training=True)
module = self.checkScript(LSTMCellS, inputs)
forward_graph = module.graph_for(*inputs)
with self.assertRaises(AssertionError):
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertExpectedGraph(forward_graph, subname='forward')

hy, cy = module(*inputs)
Expand All @@ -4324,9 +4323,8 @@ def test_milstm_fusion_cuda(self):
inputs = get_milstm_inputs('cuda', training=True)
module = self.checkScript(MiLSTMCell, inputs)
forward_graph = module.graph_for(*inputs)
with self.assertRaises(AssertionError):
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
self.assertExpectedGraph(forward_graph, subname='forward')

hy, cy = module(*inputs)
Expand Down
5 changes: 1 addition & 4 deletions torch/csrc/jit/fuser/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,13 @@ static c10::optional<std::vector<int64_t>> getMapSize(
, at::TensorList args
, at::IntList arg_subset) {

int64_t dim_after_broadcast = 0;
for (const auto arg_idx : arg_subset) {
dim_after_broadcast = std::max(dim_after_broadcast, args[arg_idx].dim());
}
// TODO: this keeps reallocating map_size at every iteration, but we know
// exactly how much storage do we need, so this could be fixed in-place at
// every step. We're just missing a few functions for ATen, but the fix
// should be straightforward.
// Note: left unitialized since empty shape is broadcastable to any shape
std::vector<int64_t> map_size;
map_size.reserve(8);
for (const auto arg_idx : arg_subset) {
auto& arg = args.at(arg_idx);
auto& chunk_desc = spec.inputChunks().at(arg_idx);
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/passes/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ void AliasDb::analyze(Node* node) {
case prim::MMTreeReduce:
case prim::MMBatchSide:
case prim::None:
case prim::BroadcastSizes:
case prim::ChunkSizes:
return analyzeCreator(node);
case prim::TupleUnpack:
case prim::TupleIndex:
Expand Down
Loading

0 comments on commit d76fd43

Please sign in to comment.