Skip to content

Commit

Permalink
[TensorExpr] LoopNest: add a constructor that takes Stmt instead of l…
Browse files Browse the repository at this point in the history
…ist of Tensors. (pytorch#45949)

Summary: Pull Request resolved: pytorch#45949

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D24156001

Pulled By: ZolotukhinM

fbshipit-source-id: 6f4f050b04e802e274c42ed64be74c21ba79c29f
  • Loading branch information
Mikhail Zolotukhin authored and facebook-github-bot committed Oct 8, 2020
1 parent 1036b77 commit 6e4de44
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torch/csrc/jit/tensorexpr/loopnest.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,22 @@ class Dtype;

class TORCH_API LoopNest {
public:
// A constructor for building a LoopNest from a list of Tensors
LoopNest(const std::vector<Tensor*>& output_tensors);

// A constructor for building a LoopNest from a pre-baked Stmt and meta-info
// TODO: Nuke intermediate_bufs_ and possibly buf_initializers from here if
// they can be deduced.
LoopNest(
Stmt* stmt,
const std::unordered_set<const Buf*>& output_bufs,
const std::unordered_set<const Buf*>& intermediate_bufs,
const std::unordered_map<const Buf*, const Expr*>& buf_initializers)
: root_stmt_(stmt),
output_bufs_(output_bufs),
intermediate_bufs_(intermediate_bufs),
buf_initializers_(buf_initializers) {}

Stmt* root_stmt() const {
return root_stmt_;
}
Expand Down

0 comments on commit 6e4de44

Please sign in to comment.