Skip to content

Commit

Permalink
[jit] Factor findAllNodes into one place. (pytorch#65965)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#65965

ghstack-source-id: 141504185

Test Plan: no behavior change

Reviewed By: qihqi, ejguan

Differential Revision: D31326152

fbshipit-source-id: 2e0261a96853bfb67a96dd68972c905b6b26d562
  • Loading branch information
zhxchen17 authored and facebook-github-bot committed Oct 25, 2021
1 parent 239b382 commit 059ae96
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 51 deletions.
23 changes: 2 additions & 21 deletions torch/csrc/jit/api/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,6 @@ std::string getInputDebugName(const Node& n, const int idx) {
return n.inputs().at(idx)->debugName();
}

std::vector<Node*> findAllNodes(
c10::ArrayRef<torch::jit::Block*> blocks,
Symbol kind,
bool recurse) {
std::vector<Node*> ret;
for (Block* block : blocks) {
for (Node* n : block->nodes()) {
if (n->kind() == kind) {
ret.push_back(n);
}
if (recurse) {
auto nodes = findAllNodes(n->blocks(), kind, recurse);
ret.insert(ret.end(), nodes.begin(), nodes.end());
}
}
}
return ret;
}

void assert_ignored_methods_not_called(
torch::jit::Function* fn,
const std::unordered_set<std::string>& ignored_methods) {
Expand All @@ -53,7 +34,7 @@ void assert_ignored_methods_not_called(
}
const bool recurse = true;
std::vector<Node*> all_nodes =
findAllNodes({fn->graph()->block()}, c10::prim::CallMethod, recurse);
findAllNodes(*fn->graph().get(), c10::prim::CallMethod, recurse);

// Extract method names from these nodes.
std::unordered_set<std::string> encountered_ignored_methods;
Expand Down Expand Up @@ -90,7 +71,7 @@ void assert_ignored_attributes_not_referenced(

const bool recurse = true;
std::vector<Node*> all_nodes =
findAllNodes({fn->graph()->block()}, c10::prim::GetAttr, recurse);
findAllNodes(*fn->graph().get(), c10::prim::GetAttr, recurse);

// Extract attribute names from these nodes.
std::unordered_set<std::string> encountered_ignored_attributes;
Expand Down
36 changes: 36 additions & 0 deletions torch/csrc/jit/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,23 @@ std::string normalizeAttrName(c10::string_view field) {
return std::string{field};
}

void findAllNodes(
Block& block,
Symbol kind,
bool recurse,
std::vector<Node*>& ret) {
for (Node* n : block.nodes()) {
if (n->kind() == kind) {
ret.push_back(n);
}
if (recurse) {
for (auto b : n->blocks()) {
findAllNodes(*b, kind, recurse, ret);
}
}
}
}

} // namespace

// NB: This overload will become ambiguous with the one Caffe2 provides in its
Expand Down Expand Up @@ -2179,6 +2196,25 @@ std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) {
return new_outputs;
}

std::vector<Node*> findAllNodes(
at::ArrayRef<Block*> array,
Symbol kind,
bool recurse) {
std::vector<Node*> ret;
for (auto block : array) {
findAllNodes(*block, kind, recurse, ret);
}
return ret;
}

std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) {
return findAllNodes({&block}, kind, recurse);
}

std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) {
return findAllNodes(*g.block(), kind, recurse);
}

std::vector<Value*> insertGraph(
Graph& g,
Graph& callee,
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,13 @@ TORCH_API std::vector<Value*> inlineCallTo(
*/
TORCH_API std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs);

TORCH_API std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(Block& b, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(
at::ArrayRef<Block*> a,
Symbol kind,
bool recurse);

struct OperatorSet {
OperatorSet(std::initializer_list<const char*> sig_literals);

Expand Down
32 changes: 2 additions & 30 deletions torch/csrc/jit/python/python_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,6 @@ std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
}
}

std::vector<Node*> findAllNodes(
c10::ArrayRef<torch::jit::Block*> blocks,
Symbol kind,
bool recurse = true) {
std::vector<Node*> ret;
for (Block* block : blocks) {
for (Node* n : block->nodes()) {
if (n->kind() == kind) {
ret.push_back(n);
}
if (recurse) {
auto nodes = findAllNodes(n->blocks(), kind, recurse);
ret.insert(ret.end(), nodes.begin(), nodes.end());
}
}
}
return ret;
}

std::vector<Node*> findAllNodes(
Block* block,
Symbol kind,
bool recurse = true) {
std::vector<Block*> blocks = {block};
return findAllNodes(blocks, kind, recurse);
}

Node* findNode(
c10::ArrayRef<torch::jit::Block*> blocks,
Symbol kind,
Expand Down Expand Up @@ -380,8 +353,7 @@ void initPythonIRBindings(PyObject* module_) {
.def(
"findAllNodes",
[](Graph& g, const std::string& kind, bool recurse) {
return findAllNodes(
g.block(), Symbol::fromQualString(kind), recurse);
return findAllNodes(g, Symbol::fromQualString(kind), recurse);
},
"Find all nodes",
py::arg("kind"),
Expand Down Expand Up @@ -509,7 +481,7 @@ void initPythonIRBindings(PyObject* module_) {
.def(
"findAllNodes",
[](Block& b, const std::string& kind, bool recurse) {
return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
return findAllNodes(b, Symbol::fromQualString(kind), recurse);
},
"Find all nodes",
py::arg("kind"),
Expand Down

0 comments on commit 059ae96

Please sign in to comment.