[jit] Remove graph() call from abstract Function interface. (pytorch#65967)

Summary:
Pull Request resolved: pytorch#65967

Graph is an implementation detail. If a user wants access to the underlying
graph, they should explicitly dynamic-cast to GraphFunction instead.
ghstack-source-id: 141659819

Test Plan: no behavior change.

Reviewed By: gmagogsfm

Differential Revision: D31326153

fbshipit-source-id: a0e984f57c6013494b92a7095bf5bb660035eb84
zhxchen17 authored and facebook-github-bot committed Oct 27, 2021
1 parent 7c48b9e commit b55a250
Showing 43 changed files with 324 additions and 261 deletions.
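Every updated call site below follows the same pattern: instead of calling graph() through the abstract Function interface, it first downcasts to GraphFunction via the helpers this commit adds in torch/csrc/jit/api/function_impl.h. A minimal sketch of the pattern, assuming a loaded script module (the dumpForwardGraph wrapper itself is hypothetical; toGraphFunction, tryToGraphFunction, and Module::get_method are from this diff):

#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>

namespace torch {
namespace jit {

// Hypothetical helper: print the forward graph of a scripted module.
void dumpForwardGraph(Module& mod) {
  Function& fn = mod.get_method("forward").function();

  // Before this commit: fn.graph(), a virtual call on Function.
  // After: downcast explicitly; asserts internally if fn is not
  // graph-backed (e.g. a BuiltinOpFunction).
  std::shared_ptr<Graph> g = toGraphFunction(fn).graph();
  g->dump();

  // When the function may not be graph-backed, probe first:
  if (GraphFunction* gf = tryToGraphFunction(fn)) {
    gf->graph()->dump();
  }
}

} // namespace jit
} // namespace torch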
6 changes: 0 additions & 6 deletions aten/src/ATen/core/builtin_function.h
@@ -68,12 +68,6 @@ struct BuiltinOpFunction : public Function {
     // nop
   }
 
-  std::shared_ptr<Graph> graph() const override {
-    TORCH_INTERNAL_ASSERT(false, "BuiltinFunction had a graph requested "
-        "from it. This probably indicates that the JIT calling context needs a "
-        "special case on Function::isGraphFunction()");
-  }
-
   std::shared_ptr<Graph> optimized_graph() const override {
     TORCH_INTERNAL_ASSERT(false, "BuiltinFunction had a graph requested "
         "from it. This probably indicates that the JIT calling context needs a "
2 changes: 0 additions & 2 deletions aten/src/ATen/core/function.h
@@ -56,8 +56,6 @@ struct TORCH_API Function {
   // if this isn't yet defined, run its method_creator function
   virtual void ensure_defined() = 0;
 
-  virtual std::shared_ptr<Graph> graph() const = 0;
-
   virtual std::shared_ptr<Graph> optimized_graph() const = 0;
 
   virtual void clear_execution_info() = 0;
2 changes: 1 addition & 1 deletion binaries/aot_model_compiler.cc
@@ -115,7 +115,7 @@ c10::IValue preprocess(
   }
 
   auto method = mod.get_method(FLAGS_method_name);
-  auto graph = method.function().graph()->copy();
+  auto graph = toGraphFunction(method.function()).graph()->copy();
   auto sizes = getInputSizes(compile_spec);
 
   std::string llvm_asm_code;
8 changes: 5 additions & 3 deletions test/cpp/jit/test_argument_spec.cpp
@@ -1,8 +1,10 @@
 #include <gtest/gtest.h>
 
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/runtime/argument_spec.h>
 #include <torch/jit.h>
 
 #include "test/cpp/jit/test_utils.h"
-#include "torch/csrc/jit/runtime/argument_spec.h"
+
 namespace torch {
 namespace jit {
@@ -136,11 +138,11 @@ TEST(ArgumentSpecTest, Basic_CUDA) {
   auto& GF = at::CUDA(at::kFloat);
   auto& GD = at::CUDA(at::kDouble);
 
-  auto graph = jit::compile(R"JIT(
+  auto graph = toGraphFunction(jit::compile(R"JIT(
   def fn(a, b, c, d, e):
     return a, b, c, d, e
   )JIT")
-                   ->get_function("fn")
+                   ->get_function("fn"))
                    .graph();
 
   ArgumentSpecCreator arg_spec_creator(*graph);
2 changes: 1 addition & 1 deletion test/cpp/jit/test_backend_compiler_preprocess.cpp
@@ -20,7 +20,7 @@ c10::IValue preprocess(
   c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
 
   for (const auto& method : mod.get_methods()) {
-    auto graph = method.function().graph()->copy();
+    auto graph = toGraphFunction(method.function()).graph()->copy();
     // Must inline the graph for debug info map.
     Inline(*graph);
     // This is here because to test module hierarchy we will have
2 changes: 1 addition & 1 deletion test/cpp/jit/test_inliner.cpp
@@ -43,7 +43,7 @@ TEST(InlinerTest, Basic) {
   CompilationUnit cu(testSource);
   auto& fn = cu.get_function("foo3");
 
-  auto g = fn.graph();
+  auto g = toGraphFunction(fn).graph();
   Inline(*g);
   FileCheck().check_count("prim::Print", 3)->run(*g);
 }
4 changes: 2 additions & 2 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -362,7 +362,7 @@ struct ClassNamespaceValue : public SugaredValue {
 
   std::shared_ptr<SugaredValue> attr(
       const SourceRange& loc,
-      Function& m,
+      GraphFunction& m,
       const std::string& name) override {
     const auto fullName = c10::QualifiedName(basename_, name);
 
@@ -387,7 +387,7 @@
 struct TestModuleResolver : public Resolver {
   std::shared_ptr<SugaredValue> resolveValue(
       const std::string& name,
-      Function& m,
+      GraphFunction& m,
      const SourceRange& loc) override {
     if (name == "torch") {
       return std::make_shared<BuiltinModule>("aten");
27 changes: 15 additions & 12 deletions test/cpp/jit/test_misc.cpp
@@ -11,6 +11,7 @@
 #include <torch/csrc/autograd/engine.h>
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/codegen/fuser/interface.h>
 #include <torch/csrc/jit/frontend/code_template.h>
@@ -473,7 +474,7 @@ TEST(ControlFlowTest, Basic) {
   auto cu = compile(cf_examples);
 
   auto run = [&](const std::string& name, std::vector<IValue> stack) {
-    auto graph = cu->get_function(name).graph();
+    auto graph = toGraphFunction(cu->get_function(name)).graph();
     Code code(graph, "");
     InterpreterState interp(code);
     interp.run(stack);
@@ -1609,7 +1610,7 @@ TEST(LoopPeelerTest, NoInductionVariableUse) {
   )JIT";
 
   auto cu = compile(str_func_def);
-  auto& f = cu->get_function("test_peel_n_times");
+  auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
   auto stack = createStack({});
   // peeling loop once
   {
@@ -1651,7 +1652,7 @@ TEST(LoopPeelerTest, YesInductionVariableUse) {
   )JIT";
 
   auto cu = compile(str_func_def);
-  auto& f = cu->get_function("test_peel_n_times");
+  auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
   auto stack = createStack({});
   // peeling loop once
   {
@@ -1697,7 +1698,7 @@ TEST(LoopPeelerTest, LoopWithTerminationCondition) {
   // the peel changes the termination condition to false
   // so the original loop doesn't run
   auto cu = compile(str_func_def);
-  auto& f = cu->get_function("test_with_cond_times");
+  auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
   auto stack = createStack({});
   // peeling 5 iterations should update the termination
   // condition to false
@@ -1742,7 +1743,7 @@ TEST(LoopPeelerTest, SimpleNestedLoops) {
   )JIT";
 
   auto cu = compile(str_func_def);
-  auto& f = cu->get_function("test_nested_loops");
+  auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
   auto stack = createStack({});
 
   {
@@ -1782,7 +1783,7 @@ TEST(LoopPeelerTest, SimpleNestedLoops2) {
   )JIT";
 
   auto cu = compile(str_func_def);
-  auto& f = cu->get_function("test_nested_loops");
+  auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
   auto stack = createStack({});
   {
     LoopsPeeler peeler(true_pred, 1);
@@ -1859,7 +1860,7 @@ TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
   )JIT";
 
   auto cu = compile(basic_example);
-  auto& fun = cu->get_function("basic");
+  auto& fun = toGraphFunction(cu->get_function("basic"));
   auto pr = ProfilingRecord::instrumentGraph(fun.graph());
   auto x = at::randn({2, 3}, at::kCPU);
   auto y = at::randn({2, 3}, at::kCPU);
@@ -1910,7 +1911,7 @@ TEST(InsertBailOutsTest, Basic) {
   )JIT";
 
   auto cu = compile(basic_example);
-  auto& fun = cu->get_function("basic_loop");
+  auto& fun = toGraphFunction(cu->get_function("basic_loop"));
   auto pr = ProfilingRecord::instrumentGraph(fun.graph());
   auto x = at::randn({2, 3}, at::kCPU);
   auto y = at::randn({2, 3}, at::kCPU);
@@ -2004,7 +2005,7 @@ def foo(x):
     return bar(x)*baz(x)*11
   )";
   auto cu = compile(text);
-  const Function& foo = cu->get_function("foo");
+  const auto& foo = toGraphFunction(cu->get_function("foo"));
   for (Node* n : foo.optimized_graph()->nodes()) {
     if (n->kind() == prim::Constant) {
       if (!n->hasAttribute(attr::value) ||
@@ -2086,7 +2087,7 @@ def c(x):
     return x
   )";
   auto cu = compile(text);
-  const Function& baz = cu->get_function("c");
+  const auto& baz = toGraphFunction(cu->get_function("c"));
   std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
   for (Node* n : baz.optimized_graph()->nodes()) {
     if (n->kind() == prim::Constant) {
@@ -2131,7 +2132,8 @@ TEST(InlinedCallStackTest, BlockAnnotation) {
       return self.A0.forward(x, y, z) + self.B0.forward(x)
   )");
 
-  auto graph = c.get_method("forward").function().optimized_graph();
+  auto graph =
+      toGraphFunction(c.get_method("forward").function()).optimized_graph();
   std::stringstream add_ss, mul_ss;
   for (Node* n : graph->nodes()) {
     if (n->kind() == prim::If) {
@@ -2192,7 +2194,8 @@ TEST(InlinedCallStackTest, SelfCallMethods) {
       return self.A0.forward(x, y) + self.call_b(x)
   )");
 
-  auto graph = c.get_method("forward").function().optimized_graph();
+  auto graph =
+      toGraphFunction(c.get_method("forward").function()).optimized_graph();
   std::unordered_map<std::string, size_t> module_hierarchies;
   for (Node* n : graph->nodes()) {
     auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
2 changes: 1 addition & 1 deletion torch/csrc/distributed/rpc/python_rpc_handler.cpp
@@ -32,7 +32,7 @@ constexpr auto kInternalModule = "torch.distributed.rpc.internal";
 struct PythonTypeResolver : public jit::Resolver {
   std::shared_ptr<jit::SugaredValue> resolveValue(
       const std::string& /* unused */,
-      torch::jit::Function& /* unused */,
+      torch::jit::GraphFunction& /* unused */,
       const jit::SourceRange& /* unused */) override {
     TORCH_INTERNAL_ASSERT(
         false, "RPC Type resolver does not need to resolve value");
37 changes: 36 additions & 1 deletion torch/csrc/jit/api/function_impl.cpp
@@ -10,7 +10,7 @@
 namespace torch {
 namespace jit {
 namespace {
-c10::FunctionSchema defaultSchemaFor(const Function& function) {
+c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
   std::vector<c10::Argument> args;
   std::vector<c10::Argument> returns;
   Graph& g = *function.graph();
@@ -26,6 +26,29 @@ c10::FunctionSchema defaultSchemaFor(const Function& function) {
   }
   return {function.name(), "", std::move(args), std::move(returns)};
 }
+
+template <typename T, typename F>
+T* tryToGraphFunctionImpl(F& function) noexcept {
+  if (!function.isGraphFunction()) {
+    return nullptr;
+  }
+
+  return static_cast<T*>(&function);
+}
+
+template <typename T, typename F>
+T& toGraphFunctionImpl(F& function) {
+  if (auto* g = tryToGraphFunctionImpl<T>(function)) {
+    return *g;
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "Failed to downcast a Function to a GraphFunction. "
+      "This probably indicates that the JIT calling context needs a "
+      "special case on tryToGraphFunction() instead.");
+}
+
 } // namespace
 
 void placeholderCreator(GraphFunction&) {
@@ -82,5 +105,17 @@ void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
   ConstantPooling(graph);
 }
 
+GraphFunction* tryToGraphFunction(Function& function) noexcept {
+  return tryToGraphFunctionImpl<GraphFunction>(function);
+}
+
+GraphFunction& toGraphFunction(Function& function) {
+  return toGraphFunctionImpl<GraphFunction>(function);
+}
+
+const GraphFunction& toGraphFunction(const Function& function) {
+  return toGraphFunctionImpl<const GraphFunction>(function);
+}
+
 } // namespace jit
 } // namespace torch
8 changes: 7 additions & 1 deletion torch/csrc/jit/api/function_impl.h
@@ -33,7 +33,7 @@ struct TORCH_API GraphFunction : public Function {
   IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
       override;
 
-  std::shared_ptr<Graph> graph() const override {
+  std::shared_ptr<Graph> graph() const {
     return graph_;
   }
 
@@ -143,5 +143,11 @@ struct TORCH_API GraphFunction : public Function {
   // before a call to setSchema
   mutable std::unique_ptr<FunctionSchema> schema_;
 };
+
+// Short hands for dynamic_cast<GraphFunction*>.
+TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
+TORCH_API GraphFunction& toGraphFunction(Function&);
+TORCH_API const GraphFunction& toGraphFunction(const Function&);
+
 } // namespace jit
 } // namespace torch
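The header also ships a const overload, so call sites that only hold a const Function& (like the updated test_misc.cpp tests above) keep const-ness through the cast. A small usage sketch, assuming cu came from compile() and defines a function named "foo" as in those tests:

#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/ir.h>

using namespace torch::jit;

// Sketch: count constant nodes in foo's optimized graph, mirroring
// the updated test_misc.cpp call sites. `cu` and "foo" are assumptions.
size_t countConstants(const std::shared_ptr<CompilationUnit>& cu) {
  // Binds Function& to const GraphFunction&; asserts if not graph-backed.
  const auto& foo = toGraphFunction(cu->get_function("foo"));
  size_t n_const = 0;
  for (Node* n : foo.optimized_graph()->nodes()) {
    if (n->kind() == c10::prim::Constant) {
      ++n_const;
    }
  }
  return n_const;
}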
2 changes: 1 addition & 1 deletion torch/csrc/jit/api/method.h
@@ -43,7 +43,7 @@ struct TORCH_API Method : public torch::IMethod {
       TaskLauncher taskLauncher = at::launch);
 
   std::shared_ptr<Graph> graph() const {
-    return function_->graph();
+    return toGraphFunction(*function_).graph();
   }
 
   const std::string& name() const override {
21 changes: 11 additions & 10 deletions torch/csrc/jit/api/module.cpp
@@ -4,6 +4,7 @@
 #include <c10/util/StringUtil.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/frontend/error_report.h>
 #include <torch/csrc/jit/frontend/ir_emitter.h>
@@ -27,14 +28,14 @@ std::string getInputDebugName(const Node& n, const int idx) {
 }
 
 void assert_ignored_methods_not_called(
-    torch::jit::Function* fn,
+    torch::jit::Function& fn,
     const std::unordered_set<std::string>& ignored_methods) {
   if (ignored_methods.empty()) {
     return;
   }
   const bool recurse = true;
-  std::vector<Node*> all_nodes =
-      findAllNodes(*fn->graph().get(), c10::prim::CallMethod, recurse);
+  std::vector<Node*> all_nodes = findAllNodes(
+      *toGraphFunction(fn).graph(), c10::prim::CallMethod, recurse);
 
   // Extract method names from these nodes.
   std::unordered_set<std::string> encountered_ignored_methods;
@@ -56,22 +57,22 @@ void assert_ignored_methods_not_called(
   TORCH_CHECK(
       false,
       "Preserved method '",
-      fn->name(),
+      fn.name(),
       "' references ignored method(s) '",
       encountered_ignored_methods_str,
       "'. This is not permitted.");
 }
 
 void assert_ignored_attributes_not_referenced(
-    torch::jit::Function* fn,
+    torch::jit::Function& fn,
     const std::unordered_set<std::string>& ignored_attributes) {
   if (ignored_attributes.empty()) {
     return;
   }
 
   const bool recurse = true;
   std::vector<Node*> all_nodes =
-      findAllNodes(*fn->graph().get(), c10::prim::GetAttr, recurse);
+      findAllNodes(*toGraphFunction(fn).graph(), c10::prim::GetAttr, recurse);
 
   // Extract attribute names from these nodes.
   std::unordered_set<std::string> encountered_ignored_attributes;
@@ -93,7 +94,7 @@ void assert_ignored_attributes_not_referenced(
   TORCH_CHECK(
       false,
       "Preserved method '",
-      fn->name(),
+      fn.name(),
       "' references ignored attribute(s) '",
       encountered_ignored_attributes_str,
       "'. This is not permitted.");
@@ -282,7 +283,7 @@ void Module::clone_method(
       return in;
     return it->second;
   };
-  auto graph = method.graph()->copy();
+  auto graph = toGraphFunction(method).graph()->copy();
   graph->remapTypes(type_remap_fn);
   auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
   const auto this_method_name = getNameForMethod(method.name());
@@ -411,8 +412,8 @@ Module Module::clone_impl(
   for (auto& fn : type()->methods()) {
     // If this method is not in the list of ignored methods, clone it.
     if (ignored_methods.count(fn->name()) == 0) {
-      assert_ignored_methods_not_called(fn, ignored_methods);
-      assert_ignored_attributes_not_referenced(fn, ignored_attributes);
+      assert_ignored_methods_not_called(*fn, ignored_methods);
+      assert_ignored_attributes_not_referenced(*fn, ignored_attributes);
       r.clone_method(*this, *fn, type_remap);
     }
   }
