Skip to content

Commit

Permalink
Revert D16197605: [jit] Make traced fns also go into the global pytho…
Browse files Browse the repository at this point in the history
…n CU

Differential Revision:
D16197605

Original commit changeset: d32c975486b0

fbshipit-source-id: a00f0490cc23824792f3e745d7b5a003b1a33d20
  • Loading branch information
suo authored and facebook-github-bot committed Jul 16, 2019
1 parent a326aad commit c5afdd0
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 31 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/qualified_name.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct QualifiedName {

// `name` can be a dotted string, like "foo.bar.baz", or just a bare name.
/* implicit */ QualifiedName(const std::string& name) {
TORCH_CHECK(!name.empty());
AT_ASSERT(!name.empty());
// split the string into its atoms.
size_t startSearchFrom = 0;
size_t pos = name.find(delimiter_, startSearchFrom);
Expand Down
1 change: 0 additions & 1 deletion torch/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __getattr__(self, op_name):
# with qualified_op_name
torch.jit._register_builtin(op, qualified_op_name)
setattr(self, op_name, op)
op.__module__ = self.__module__ + "." + self.name
return op


Expand Down
12 changes: 6 additions & 6 deletions torch/csrc/jit/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,29 +428,29 @@ void initJITBindings(PyObject* module) {

m.def(
"_jit_get_operation",
[](const std::string& op_name) {
[](const std::string& qualified_name) {
try {
auto symbol = Symbol::fromQualString(op_name);
auto symbol = Symbol::fromQualString(qualified_name);
auto operations = getAllOperatorsFor(symbol);
TORCH_CHECK(!operations.empty(), "No such operator ", op_name);
TORCH_CHECK(!operations.empty(), "No such operator ", qualified_name);
TORCH_CHECK(
operations.size() == 1,
"Found ",
operations.size(),
" overloads for operator ",
op_name,
qualified_name,
"! Overloads are not supported from Python.");
std::shared_ptr<Operator> op = operations[0];
AT_ASSERT(op != nullptr);
std::ostringstream docstring;
docstring << "Automatically bound operator '" << op_name
docstring << "Automatically bound operator '" << qualified_name
<< "' with schema: " << op->schema();
return py::cpp_function(
[op](py::args args, py::kwargs kwargs) {
return invokeOperatorFromPython(
*op, std::move(args), std::move(kwargs));
},
py::name(symbol.toUnqualString()),
py::name(qualified_name.c_str()),
py::doc(docstring.str().c_str()));
} catch (const c10::Error& error) {
throw std::runtime_error(error.what_without_backtrace());
Expand Down
6 changes: 1 addition & 5 deletions torch/csrc/jit/script/compilation_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ struct TORCH_API CompilationUnit {

Function* create_function(
c10::QualifiedName name,
std::shared_ptr<Graph> graph,
bool shouldMangle = false) {
if (shouldMangle) {
name = c10::QualifiedName(name.prefix(), mangle(name.name()));
}
std::shared_ptr<Graph> graph) {
auto fn = torch::make_unique<Function>(
std::move(name), is_optimized(), std::move(graph), nullptr);
auto ret = fn.get();
Expand Down
21 changes: 8 additions & 13 deletions torch/csrc/jit/script/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,8 @@ void initJitScriptBindings(PyObject* module) {
auto graph = tracer::createGraphByTracing(
func, typed_inputs, var_lookup_fn, force_outplace, &self);
const auto method_name = QualifiedName(self.name(), name);
auto fn = self.class_compilation_unit()->create_function(
self.module_object()->compilation_unit()->create_function(
method_name, graph);
self.type()->addMethod(fn);
didFinishEmitModule(self);
})
.def(
Expand Down Expand Up @@ -655,13 +654,9 @@ void initJitScriptBindings(PyObject* module) {
[](const StrongFunctionPtr& self) {
return self.function_->get_executor().getDebugState();
})
.def_property_readonly(
"name",
[](const StrongFunctionPtr& self) { return self.function_->name(); })
.def_property_readonly(
"qualified_name", [](const StrongFunctionPtr& self) {
return self.function_->qualname().qualifiedName();
});
.def_property_readonly("name", [](const StrongFunctionPtr& self) {
return self.function_->name();
});

py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
.def(
Expand Down Expand Up @@ -727,10 +722,10 @@ void initJitScriptBindings(PyObject* module) {
auto typed_inputs = toTypedStack(input_tuple);
auto graph = tracer::createGraphByTracing(
func, typed_inputs, var_lookup_fn, force_outplace);
auto cu = get_python_cu();
auto name = c10::QualifiedName(qualname);
auto result = cu->create_function(
std::move(name), std::move(graph), /*shouldMangle=*/true);
// TODO this should go in the global Python CU
auto cu = std::make_shared<CompilationUnit>();
const auto name = c10::QualifiedName(qualname);
auto result = cu->create_function(std::move(name), std::move(graph));
StrongFunctionPtr ret(std::move(cu), result);
didFinishEmitFunction(ret);
return ret;
Expand Down
6 changes: 1 addition & 5 deletions torch/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def weighted_kernel_sum(self, weight):
raise AttributeError("trace doesn't support compiling individual module's functions.\n"
"Please use trace_module")

name = _qualified_name(func)
name = getattr(func, '__name__', 'forward')
if name == '<lambda>':
name = '_lambda' # make name a valid identifier
traced = torch._C._create_function_from_trace(name, func, example_inputs,
Expand Down Expand Up @@ -1040,10 +1040,6 @@ def whichmodule(obj):

# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
def _qualified_name(obj):
# short-circuit in cases where the object already has a known qualified name
if isinstance(obj, torch._C.Function):
return obj.qualified_name

name = obj.__name__
module_name = obj.__module__

Expand Down

0 comments on commit c5afdd0

Please sign in to comment.