[TensorExpr] Nuke DepTracker and findAllNeededTensors. (pytorch#54997)
Summary:
Pull Request resolved: pytorch#54997

DepTracker was used to automatically pull in dependent computations starting
from the output ones. While this seems quite convenient, it has led to several
architectural issues, which are fixed in this stack.

DepTracker worked on Tensors, each of which is a pair of a Buf and a Stmt.
However, the Stmt could become stale, and there was no way to reliably update
the corresponding Tensor. We now use Bufs and Stmts directly and are moving
away from Tensors to avoid these problems.

Removing DepTracker also allowed us to unify Loads and FunctionCalls, which
were essentially duplicates of each other.
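
To illustrate the user-visible change, here is a minimal sketch abridged from
the bench_compile.cpp diff below (the te::Compute definitions of relu, min6,
plus3, times, and sixth are unchanged and omitted; the tensor names come
straight from that benchmark):

  // Previously, DepTracker walked back from the output tensor and pulled in
  // the intermediate computations automatically:
  //   te::LoopNest nest({sixth});
  //
  // Now the tensors to compute are listed explicitly as a second argument,
  // alongside the outputs:
  te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});

  // Scheduling transformations such as computeInline work as before:
  for (auto tensor : {relu, min6, plus3, times}) {
    nest.computeInline(tensor->buf());
  }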

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D27446414

Pulled By: ZolotukhinM

fbshipit-source-id: a2a32749d5b28beed92a601da33d126c0a2cf399
Mikhail Zolotukhin authored and facebook-github-bot committed Apr 2, 2021
1 parent 0d47374 commit 688e350
Showing 13 changed files with 62 additions and 136 deletions.
4 changes: 2 additions & 2 deletions benchmarks/cpp/tensorexpr/bench_compile.cpp
@@ -28,7 +28,7 @@ static void BM_CompileSwish(benchmark::State& state) {
te::Tensor* sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
return times->call(i) * 1.f / 6.f;
});
te::LoopNest nest({sixth});
te::LoopNest nest({sixth}, {relu, min6, plus3, times});
for (auto tensor : {relu, min6, plus3, times}) {
nest.computeInline(tensor->buf());
}
@@ -58,7 +58,7 @@ static void BM_CompileSwishLLVMOnly(benchmark::State& state) {
te::Tensor* sixth = te::Compute("sixth", {{n, "n"}}, [&](const te::VarHandle& i) {
return times->call(i) * 1.f / 6.f;
});
te::LoopNest nest({sixth});
te::LoopNest nest({sixth}, {relu, min6, plus3, times, sixth});
for (auto tensor : {relu, min6, plus3, times}) {
nest.computeInline(tensor->buf());
}
2 changes: 1 addition & 1 deletion test/cpp/tensorexpr/test_ir_printer.cpp
@@ -92,7 +92,7 @@ TEST(IRPrinter, FunctionName) {
R"IR(
# CHECK: for (int i
# CHECK: for (int j
# CHECK: consumer[i, j] = i * (chunk_1(i, j)IR";
# CHECK: consumer[i, j] = i * (chunk_1(i, j))IR";

torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
}
2 changes: 1 addition & 1 deletion test/cpp/tensorexpr/test_llvm.cpp
@@ -1682,7 +1682,7 @@ TEST(LLVM, CompositeParallel) {
[=](const VarHandle& m, const VarHandle& n) {
return t3->call(m, n) + m + n;
});
LoopNest loop_nest({t4});
LoopNest loop_nest({t4}, {t1, t2, t3, t4});
std::vector<For*> loop_list;
{
auto const& loops = loop_nest.getLoopStmtsFor(t1);
64 changes: 32 additions & 32 deletions test/cpp/tensorexpr/test_loopnest.cpp
@@ -767,7 +767,7 @@ TEST(LoopNest, ScheduleFunctionCall01) {
return c->call(m, n, k) + 1;
});

LoopNest l({d});
LoopNest l({d}, {c, d});
l.prepareForCodegen();
Stmt* stmt = l.root_stmt();
std::ostringstream oss;
@@ -827,7 +827,7 @@ TEST(LoopNest, ScheduleInlineSimple) {
return c_buf.load(m, n) * d_buf.load(m, k) + x->call(m, n, k);
});

LoopNest l1({y});
LoopNest l1({y}, {x, y});
LoopNest l2(l1);
l2.computeInline(x->buf());

@@ -914,7 +914,7 @@ void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
return x->call(m, n, k) + y->call(m, n, k);
});

LoopNest l({z});
LoopNest l({z}, {x, y, z});
for (const std::string& order : inline_order) {
if (order == "x") {
l.computeInline(x->buf());
@@ -1023,7 +1023,7 @@ TEST(LoopNest, ScheduleInlineRandom) {
return x->call(m, n, k) + x->call(m, n, k);
});

LoopNest l1({y});
LoopNest l1({y}, {x, y});
l1.computeInline(x->buf());

// would normally compare results but Rand isn't implemented in the
@@ -1060,7 +1060,7 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) {
Intrinsics::make(kRand, kInt);
});

LoopNest l1({y});
LoopNest l1({y}, {x, y});
l1.computeInline(x->buf());

// would normally compare results but Rand isn't implemented in the
@@ -1093,7 +1093,7 @@ TEST(LoopNest, ScheduleInlineRandomLowerDimensions) {
return x->call(m) + x->call(m);
});

LoopNest l1({y});
LoopNest l1({y}, {x, y});
l1.computeInline(x->buf());

// would normally compare results but Rand isn't implemented in the
@@ -1145,7 +1145,7 @@ TEST(LoopNest, ScheduleInlineIntrinsics) {
}
}

LoopNest l1({y});
LoopNest l1({y}, {x, y});
LoopNest l2(l1);
l2.computeInline(x->buf());

@@ -1190,7 +1190,7 @@ TEST(LoopNest, ScheduleInlineRandWithIntrinsics) {
return Intrinsics::make(kSqrt, x->call(m, n, k));
});

LoopNest l1({y});
LoopNest l1({y}, {x, y});
l1.computeInline(x->buf());

Stmt* stmt1 = IRSimplifier::simplify(l1.root_stmt());
@@ -1216,7 +1216,7 @@ TEST(LoopNest, ScheduleSplitAThenInline) {
For* i_outer;
For* i_inner;

LoopNest l({b});
LoopNest l({b}, {a, b});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.splitWithMask(loops[0], 4, &i_outer, &i_inner);
ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices");
@@ -1234,7 +1234,7 @@ TEST(LoopNest, ScheduleSplitBThenInline) {
For* i_outer;
For* i_inner;

LoopNest l({b});
LoopNest l({b}, {a, b});
std::vector<For*> loops = l.getLoopStmtsFor(b);
l.splitWithMask(loops[0], 3, &i_outer, &i_inner);
l.computeInline(a->buf());
@@ -1261,7 +1261,7 @@ TEST(LoopNest, ScheduleSplitTwiceThenInline) {
For* i_outer;
For* i_inner;

LoopNest l({b});
LoopNest l({b}, {a, b});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.splitWithMask(loops[0], 4, &i_outer, &i_inner);
l.splitWithMask(i_inner, 2, &i_outer, &i_inner);
@@ -1280,7 +1280,7 @@ TEST(LoopNest, ScheduleInlineThenSplit) {
For* i_outer;
For* i_inner;

LoopNest l({b});
LoopNest l({b}, {a, b});
l.computeInline(a->buf());

std::vector<For*> loops = NodeFinder<For>::find(l.root_stmt());
@@ -1308,7 +1308,7 @@ TEST(LoopNest, ScheduleSplitInlineThenSplit) {
For* i_outer;
For* i_inner;

LoopNest l({b});
LoopNest l({b}, {a, b});
auto loops = NodeFinder<For>::find(l.root_stmt());
l.splitWithMask(loops.back(), 2, &i_outer, &i_inner);
l.computeInline(a->buf());
@@ -1339,7 +1339,7 @@ TEST(LoopNest, ScheduleSplitInlineSimplify) {
For* i_outer;
For* i_inner;

LoopNest l({b});
LoopNest l({b}, {a, b});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.splitWithMask(loops[0], 4, &i_outer, &i_inner);
ASSERT_THROWS_WITH(l.computeInline(a->buf()), "compound indices");
@@ -1358,7 +1358,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedOnce) {
return a->call(k) * b->call(l);
});

LoopNest l({c});
LoopNest l({c}, {a, b, c});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.computeInline(a->buf());
l.prepareForCodegen();
@@ -1388,7 +1388,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedTwice) {
return a->call(k) * b->call(l);
});

LoopNest l({c});
LoopNest l({c}, {a, b, c});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.computeInline(a->buf());
l.computeInline(b->buf());
@@ -1419,7 +1419,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedInner) {
return a->call(k) * b->call(l);
});

LoopNest l({c});
LoopNest l({c}, {a, b, c});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.computeInline(b->buf());
l.prepareForCodegen();
@@ -1451,7 +1451,7 @@ TEST(LoopNest, ScheduleInlineThreeMixedSplit) {

For* i_outer;
For* i_inner;
LoopNest l({c});
LoopNest l({c}, {a, b, c});
std::vector<For*> loops = l.getLoopStmtsFor(a);
l.splitWithMask(loops[0], 4, &i_outer, &i_inner);
loops = l.getLoopStmtsFor(b);
@@ -1555,7 +1555,7 @@ TEST(LoopNest, ScheduleFuserThreeArg) {
return f->call(i) + d.load(i);
});

LoopNest l({g});
LoopNest l({g}, {e, f, g});
l.computeInline(l.getLoopBodyFor(e));
l.computeInline(l.getLoopBodyFor(f));
l.prepareForCodegen();
@@ -1619,7 +1619,7 @@ TEST(LoopNest, LoopNestComputeAt_1) {
"A", {{N, "i_a"}}, [&](const VarHandle& i_a) { return i_a * i_a; });
Tensor* B = Compute(
"B", {{N, "i_b"}}, [&](const VarHandle& i_b) { return A->call(i_b); });
LoopNest l({B});
LoopNest l({B}, {A, B});
std::vector<For*> loops = l.getLoopStmtsFor(B);
l.computeAt(l.getLoopBodyFor(A), loops[0]);
l.prepareForCodegen();
@@ -1682,7 +1682,7 @@ TEST(LoopNest, LoopNestComputeAt_2) {
c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
}
}
LoopNest orig_loopnest({c});
LoopNest orig_loopnest({c}, {p, c});

{
// First let's try to compute P at axis cy (the outer loop)
@@ -1782,7 +1782,7 @@ TEST(LoopNest, LoopNestComputeAt_3) {
}
}

LoopNest orig_loopnest({D});
LoopNest orig_loopnest({D}, {A, B, C, D});
{
// First let's try to compute A at axis dy (the outer loop)
LoopNest l(orig_loopnest);
@@ -1873,7 +1873,7 @@ TEST(LoopNest, Reduce2dComputeAt) {
c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1);
}
}
LoopNest orig_loopnest({c});
LoopNest orig_loopnest({c}, {p, c});
checkIR(orig_loopnest.root_stmt(), R"IR(
# CHECK: for (int py = 0; py < H + 1; py++) {
# CHECK: for (int px = 0; px < W + 1; px++) {
@@ -2538,7 +2538,7 @@ TEST(LoopNest, LoopNestReorderInternalLoopNest) {
return x->call(m, n, k) + y->call(m, n, k);
});

LoopNest l({z});
LoopNest l({z}, {x, y, z});
For* a = nullptr;
For* b = nullptr;
auto fors = NodeFinder<For>::find(l.root_stmt());
@@ -3417,7 +3417,7 @@ TEST(LoopNest, DetectInlineRankMismatch) {
"reshape",
{{kTotalSize / 2, "i"}, {2, "j"}},
[&](const VarHandle& i, const VarHandle& j) { return a->call(i, j); });
LoopNest l({reshape});
LoopNest l({reshape}, {a, reshape});
ASSERT_THROWS_WITH(
l.computeInline(l.getLoopBodyFor(a)),
"Placeholder indexed access is inconsistent with its rank");
@@ -3439,7 +3439,7 @@ TEST(LoopNest, CacheReadsSimple) {
return A->call(i + 10, j + 20) + A->call(i + 30, j + 40);
});

LoopNest l({B, C});
LoopNest l({B, C}, {A, B, C});
Stmt* j_loop = l.getLoopStmtsFor(B)[1];
l.cacheAccesses(A->buf(), "A_local", j_loop);

@@ -3507,7 +3507,7 @@ TEST(LoopNest, CacheReadsOuter) {
return A->call(i + 10, j + 20) + A->call(i + 30, j + 40);
});

LoopNest l({B, C});
LoopNest l({B, C}, {A, B, C});
Stmt* i_loop = l.getLoopStmtsFor(B)[0];
l.cacheAccesses(A->buf(), "A_local", i_loop);

@@ -3555,7 +3555,7 @@ TEST(LoopNest, CacheReadsInternal) {
return A->call(i + 10, j + 20) + A->call(i + 30, j + 40);
});

LoopNest l({B, C});
LoopNest l({B, C}, {A, B, C});
Stmt* j_loop = l.getLoopStmtsFor(B)[1];
l.cacheAccesses(A->buf(), "A_local", j_loop);
l.prepareForCodegen();
@@ -3603,7 +3603,7 @@ TEST(LoopNest, CacheReadsInner) {
return A->call(i + 10, j + 20) + A->call(i + 30, j + 40);
});

LoopNest l({B, C});
LoopNest l({B, C}, {A, B, C});
Stmt* body = l.getLoopBodyFor(B);
l.cacheAccesses(A->buf(), "A_local", body);
l.prepareForCodegen();
@@ -3650,7 +3650,7 @@ TEST(LoopNest, CacheWritesSimple) {
return A->call(i + 10, j + 20) + A->call(i + 30, j + 40);
});

LoopNest l({B, C});
LoopNest l({B, C}, {A, B, C});
Stmt* a_loop = l.getLoopStmtsFor(A)[1];
l.cacheAccesses(A->buf(), "A_local", a_loop);

@@ -3830,7 +3830,7 @@ TEST(LoopNest, InlineConstantIndex) {
return y->call(m, n, o);
});

LoopNest l({z});
LoopNest l({z}, {y, z});
l.simplify();
ASSERT_TRUE(l.computeInline(y->buf()));
}
@@ -3858,7 +3858,7 @@ TEST(LoopNest, CompoundTensorUsed) {
return A->call(i, j + 1) + A->call(i, j + 2);
});

LoopNest l({B});
LoopNest l({B}, {A, B});
ASSERT_FALSE(l.computeInline(A->buf()));
l.prepareForCodegen();

6 changes: 3 additions & 3 deletions test/cpp/tensorexpr/test_memdependency.cpp
@@ -2767,7 +2767,7 @@ TEST(MemDependency, MemDependencyCheckerComputeAPI) {
return c->call(m, n, k) + 1;
});

LoopNest l({d});
LoopNest l({d}, {c, d});

MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()});

@@ -2814,7 +2814,7 @@ TEST(MemDependency, MemDependencyCheckerComputeInline) {
return c->call(m, n, k) + 1;
});

LoopNest l({d});
LoopNest l({d}, {c, d});
l.computeInline(c->buf());

MemDependencyChecker analyzer({a_buf.data(), b_buf.data()}, {d->buf()});
@@ -2964,7 +2964,7 @@ TEST(MemDependency, MemDependencyCheckerComputeReduce) {
return b.load(l, n, m) * a.load(l, n, m);
});
Tensor* d = Reduce("sum", {{2, "l1"}}, Sum(), c, {{3, "n1"}, {6, "m1"}});
LoopNest l({d});
LoopNest l({d}, {c, d});

MemDependencyChecker analyzer({a.data(), b.data()}, {d->buf()});

