Adds a bool is_available() method to the backend contract (pytorch#53068)

Summary:
Pull Request resolved: pytorch#53068

Adds a `bool is_available()` method to the backend contract: it returns `true` if `compile()` and `execute()` can be called, and `false` otherwise.
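
For illustration, a minimal backend satisfying the extended contract could look like the following sketch, modeled on the TestBackend and preprocess function added in this diff (the MyBackend class and the "my_backend" registration name are placeholders, not part of this change):

```cpp
#include <torch/csrc/jit/backends/backend.h>

namespace {

// Minimal sketch of a backend implementing the extended contract.
class MyBackend : public torch::jit::PyTorchBackendInterface {
 public:
  explicit MyBackend() {}
  virtual ~MyBackend() = default;

  // New in this contract: report whether compile()/execute() can be called.
  bool is_available() override {
    return true;
  }

  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    // Map each method name in the compile spec to an opaque handle.
    auto handles =
        c10::impl::GenericDict(c10::StringType::get(), c10::AnyType::get());
    for (const auto& entry : method_compile_spec) {
      handles.insert(entry.key(), entry.key());
    }
    return handles;
  }

  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    // A real backend would dispatch to its runtime here; this stub
    // returns an empty output list.
    return c10::impl::GenericList(c10::AnyType::get());
  }
};

// preprocess always runs on-host at LoweredModule creation time.
c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
  return mod._ivalue();
}

static auto cls = torch::jit::backend<MyBackend>("my_backend", preprocess);

} // namespace
```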

It is used to implement the following changes in the `LoweredModule` (a condensed sketch follows the list):
* `compile()` in `__setstate__` runs if `is_available()` returns true; otherwise `__setstate__` throws an exception ("Backend is not available.").
* `compile()` at `LoweredModule` creation runs if `is_available()` returns true; otherwise a warning is emitted.
* `execute()` runs only if `is_available()` returns true; otherwise it throws an exception ("Backend is not available.").
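
These semantics are exercised end-to-end by the new ToBackendNotAvailable test in test/cpp/jit/test_backend.cpp below; condensed (test-function body only):

```cpp
// Condensed from test/cpp/jit/test_backend.cpp below: lowering to an
// unavailable backend succeeds (with a warning), but execution throws.
Module m("m");
m.define(R"(
  def forward(self, x, h):
      return x + h
)");

c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());

// No exception at creation time; only a warning is emitted.
auto lm = torch::jit::detail::codegen_backend_module(
    "test_backend_unavailable", m, compile_spec, any_dict_ty);

// Any call into the lowered module now throws.
std::vector<IValue> inputs{2.0 * torch::ones({}), 1.0 * torch::ones({})};
ASSERT_THROWS_WITH_MESSAGE(lm.forward(inputs), "Backend is not available.");
```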

The goal of these changes is to ensure well-defined behaviour for the different combinations of backend availability on-host and on-target.

More specifically, backends may have different capabilities to compile and/or execute the Module, depending on whether this happens on-host (i.e. where the program is being written) or on-target (where the program is being executed).

First of all, we know that "preprocess" always takes place, and that it happens only on-host at creation time. So we can assume that, if any compilation is needed and possible on-host, all of it can be pushed into preprocess.

Overall, we want to ensure the following:

**On host**

| compile | execute | Outcome |
| -- | -- | -- |
| No | No | On module creation, LoweredModule is generated with a warning (since compilation and execution can still take place on-target). On module load, an exception is thrown (since execution is not possible). |
| No | Yes | This configuration should not be possible: it assumes the full compiler is not available, so even if some work was done in preprocess, the program cannot be finalized for execution. |
| Yes | No | Here the expectation is for is_available() to return false and for the compilation logic to move into preprocess (see the sketch after this table). |
| Yes | Yes | All good. This is the only on-host case where is_available() should return true. |
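
For the Yes/No row, a hypothetical sketch of a host-compile-only backend (the HostCompileOnlyBackend class and compileToArtifact helper are illustrative assumptions, not part of this diff):

```cpp
#include <torch/csrc/jit/backends/backend.h>

// Hypothetical AOT compiler entry point (assumption, not part of this diff).
c10::IValue compileToArtifact(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec);

// Host-side backend for the on-host Yes/No row: the full compiler runs in
// preprocess (always on-host), so is_available() returns false here and
// compile()/execute() are never reached on this host.
class HostCompileOnlyBackend : public torch::jit::PyTorchBackendInterface {
 public:
  bool is_available() override {
    return false; // no on-host execution; a target build would return true
  }

  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    TORCH_CHECK(false, "Backend is not available.");
    // Unreachable; keeps the compiler happy about the return path.
    return c10::impl::GenericDict(c10::StringType::get(), c10::AnyType::get());
  }

  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    TORCH_CHECK(false, "Backend is not available.");
    return c10::impl::GenericList(c10::AnyType::get());
  }
};

// All compilation is pushed into preprocess, which runs at creation time.
c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
  return compileToArtifact(mod, method_compile_spec);
}
```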

**On target**

| compile | execute | Outcome |
| -- | -- | -- |
| No | No | Loading the LoweredModule throws an exception, since execution is not possible. |
| No | Yes | Effectively another instance of Yes/Yes: compilation per se may not be possible on device, so compile() can be called without issue but is a no-op, and is_available() should return true (see the sketch after this table). Consequently, loading the LoweredModule succeeds if the preprocessed module is ready for execution, and fails with an exception otherwise. |
| Yes | No | This configuration should not be possible; included here only for completeness. |
| Yes | Yes | All good. This and the No/Yes case (where compilation is assumed to have happened on-host, making it just another instance of Yes/Yes) are the on-target cases where is_available() should return true. |
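
For the No/Yes row, a hypothetical sketch of an on-target backend that executes an artifact precompiled on-host (PrecompiledTargetBackend is an illustrative name, not part of this diff):

```cpp
#include <torch/csrc/jit/backends/backend.h>

// On-target backend for the No/Yes row: the real compilation happened
// on-host in preprocess, so compile() is a no-op that just maps each
// method to the precompiled artifact, and is_available() is true.
class PrecompiledTargetBackend : public torch::jit::PyTorchBackendInterface {
 public:
  bool is_available() override {
    return true; // ready to execute the artifact produced on-host
  }

  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    // No device-side compilation: hand back handles into 'processed'.
    auto handles =
        c10::impl::GenericDict(c10::StringType::get(), c10::AnyType::get());
    for (const auto& entry : method_compile_spec) {
      handles.insert(entry.key(), processed); // handle = precompiled blob
    }
    return handles;
  }

  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    // A real backend would run the precompiled method on its runtime here.
    return c10::impl::GenericList(c10::AnyType::get());
  }
};
```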

**Refactoring existing code**
This change also updates existing backend code (e.g. Glow) to implement the is_available() method with the same behaviour as before this change (i.e. always available).

This should not cause backward incompatibilities with already-saved models, since we're only adding a new method to the PyTorchBackendInterface.
Models saved with the old interface, which didn't have is_available(), will still find the other two methods in the bound object (i.e. compile and execute), and the saved LoweredModule logic will be the old one.

**Future**
We plan to use is_available() to implement support for fallback to the PyTorch interpreter.
ghstack-source-id: 123498571

Test Plan: Added C++ (test_backend.cpp) and Python (test_backends.py) tests to validate the exceptions.

Reviewed By: jackm321, spaugh, iseeyuan

Differential Revision: D26615833

fbshipit-source-id: 562e8b11db25784348b5f86bbc4179aedf15e0d3
raziel authored and facebook-github-bot committed Mar 10, 2021
1 parent 215950e commit c5cd993
Showing 10 changed files with 181 additions and 14 deletions.
33 changes: 33 additions & 0 deletions test/cpp/jit/test_backend.cpp
@@ -79,6 +79,39 @@ TEST(BackendTest, ToBackend) {
AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor()));
}

TEST(BackendTest, ToBackendNotAvailable) {
Module m("m");
m.define(R"(
def forward(self, x, h):
return self.accum(x, h), self.sub_accum(x, h)
def accum(self, x, h):
return x + h
def sub_accum(self, x, h):
return x - h
)");

std::vector<IValue> inputs;
inputs.emplace_back(2.0 * torch::ones({}));
inputs.emplace_back(1.0 * torch::ones({}));
auto ref = m.forward(inputs).toTuple()->elements();

c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// Produce lowered module (backend not available).
// Exception is not thrown at this point.
auto lm = torch::jit::detail::codegen_backend_module(
"test_backend_unavailable", m, compile_spec, any_dict_ty);
// Validate exception is thrown when trying to execute and
// the backend is not available.
ASSERT_THROWS_WITH_MESSAGE(
lm.forward(inputs).toTuple()->elements(), "Backend is not available.");
}

TEST(BackendTest, TestCompiler) {
Module m("m");
m.define(R"(
4 changes: 4 additions & 0 deletions test/cpp/jit/test_backend_compiler_lib.cpp
@@ -49,6 +49,10 @@ class BackendWithCompiler : public PyTorchBackendInterface {
explicit BackendWithCompiler() {}
virtual ~BackendWithCompiler() = default;

bool is_available() override {
return true;
}

// Since the actual compilation is done AOT,
c10::impl::GenericDict compile(
c10::IValue processed,
11 changes: 10 additions & 1 deletion test/cpp/jit/test_backend_lib.cpp
@@ -6,12 +6,17 @@ namespace jit {
// necessary to test that the JIT backend registration endpoints and
// code generation are working correctly. It is not intended to
// produce numerically correct results.
template <bool isAvailable>
class TestBackend : public PyTorchBackendInterface {
public:
// Constructor.
explicit TestBackend() {}
virtual ~TestBackend() = default;

bool is_available() override {
return isAvailable;
}

c10::impl::GenericDict compile(
c10::IValue processed,
c10::impl::GenericDict method_compile_spec) override {
@@ -68,7 +73,11 @@ c10::IValue preprocess(
return mod._ivalue();
}

static auto cls = torch::jit::backend<TestBackend>("test_backend", preprocess);
static auto cls_available =
torch::jit::backend<TestBackend<true>>("test_backend", preprocess);
static auto cls_unavailable = torch::jit::backend<TestBackend<false>>(
"test_backend_unavailable",
preprocess);
} // namespace

} // namespace jit
4 changes: 4 additions & 0 deletions test/custom_backend/custom_backend.h
@@ -12,6 +12,10 @@ class CustomBackend : public torch::jit::PyTorchBackendInterface {
explicit CustomBackend() {}
virtual ~CustomBackend() = default;

bool is_available() override {
return true;
}

c10::impl::GenericDict compile(
c10::IValue processed,
c10::impl::GenericDict method_compile_spec) override {
53 changes: 51 additions & 2 deletions test/jit/test_backends.py
@@ -1,4 +1,5 @@
from torch.testing._internal.jit_utils import JitTestCase
import io
import os
import sys
import unittest
@@ -61,6 +62,9 @@ def sub_accum(self, x, h):
return x - h


# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test")
class JitBackendTestCase(JitTestCase):
"""
A common base class for JIT backend tests that contains common utility
@@ -69,8 +73,6 @@ class JitBackendTestCase(JitTestCase):

def setUp(self):
super().setUp()
if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE:
raise unittest.SkipTest("non-portable load_library call used in test")
torch_root = Path(__file__).resolve().parent.parent.parent
p = torch_root / 'build' / 'lib' / 'libjitbackend_test.so'
torch.ops.load_library(str(p))
@@ -168,6 +170,46 @@ def test_save_load(self):
self.test_execution()


class BasicModuleUnavailableTest(JitBackendTestCase):
"""
Tests for BasicModule with a backend that is not available.
Fundamentally:
* _jit_to_backend is successful.
* Execution fails with an exception.
* Saving is successful.
* Loading fails with an exception.
"""

def setUp(self):
super().setUp()
# Create Python, JIT and backend versions of BasicModule.
self.module = BasicModule()
self.scripted_module = torch.jit.script(BasicModule())
self.lowered_module = torch._C._jit_to_backend(
"test_backend_unavailable",
self.scripted_module,
{"forward": {"": ""}},
)

def test_execution(self):
# Test execution with backend fails because the backend is not available.
input = torch.randn(5)

# Test exception is thrown.
with self.assertRaisesRegex(Exception, r"Backend is not available."):
backend_method = self.lowered_module.__getattr__("forward")
backend_output = backend_method(*(input, input))

@skipIfRocm
def test_save_load(self):
# Test that saving the lowered module is OK but loading fails because the backend is not available.
buffer = io.BytesIO()
torch.jit.save(self.lowered_module, buffer)
buffer.seek(0)
with self.assertRaisesRegex(Exception, r"Backend is not available."):
imported = torch.jit.load(buffer)


class NestedModuleTest(JitBackendTestCase):
"""
Tests for NestedModule that check that a module lowered to a backend can be used
@@ -376,6 +418,9 @@ def test_errors(self):
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"])


# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test")
class TestBackends(JitTestCase):
"""
This class wraps and invokes all subclasses of JitBackendTestCase so that each one
@@ -385,25 +430,29 @@ class TestBackends(JitTestCase):
def __init__(self, name):
super().__init__(name)
self.basic_module_test = BasicModuleTest(name)
self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
self.nested_module_test = NestedModuleTest(name)
self.selective_lowering_test = SelectiveLoweringTest(name)

def setUp(self):
super().setUp()
if not TEST_WITH_ROCM:
self.basic_module_test.setUp()
self.basic_module_unavailable_test.setUp()
self.nested_module_test.setUp()
self.selective_lowering_test.setUp()

@skipIfRocm
def test_execution(self):
self.basic_module_test.test_execution()
self.basic_module_unavailable_test.test_execution()
self.nested_module_test.test_execution()
self.selective_lowering_test.test_execution()

@skipIfRocm
def test_save_load(self):
self.basic_module_test.test_save_load()
self.basic_module_unavailable_test.test_save_load()
self.nested_module_test.test_save_load()
self.selective_lowering_test.test_save_load()

4 changes: 4 additions & 0 deletions torch/csrc/jit/backends/backend.h
@@ -28,6 +28,10 @@ class backend {
static auto cls =
torch::class_<TBackendInterface>(detail::kBackendsNamespace, name)
.def(torch::init<>())
._def_unboxed(
"is_available",
detail::getIsAvailableFunc<TBackendInterface>(),
detail::getIsAvailableSchema())
._def_unboxed(
"compile",
detail::getCompileFunc<TBackendInterface>(),
71 changes: 60 additions & 11 deletions torch/csrc/jit/backends/backend_detail.cpp
@@ -10,6 +10,17 @@
namespace torch {
namespace jit {
namespace detail {
c10::FunctionSchema getIsAvailableSchema() {
c10::Argument self("self", c10::AnyType::get());
c10::Argument available("available", c10::BoolType::get());
c10::FunctionSchema preprocessor_schema(
"is_available",
/*overload_name=*/"",
/*arguments=*/{self},
/*returns=*/{available});
return preprocessor_schema;
}

c10::FunctionSchema getCompileSchema() {
c10::Argument self("self", c10::AnyType::get());
c10::Argument mod("processed", c10::AnyType::get());
@@ -147,12 +158,29 @@ Module codegen_backend_module(
loweredModule.define(
create_backend_ct.format(create_backend_te), loweredModuleResolver());

// Helper function to expose backend.is_available() to Module generation code.
// Assumes self.__backend exists (i.e. __create_backend() has already been
// invoked).
loweredModule.define(
R"(
def __is_available(self):
return self.__backend.is_available()
)",
loweredModuleResolver());

// getstate and setstate are for serialization/deserialization of
// the LoweredModule.
// setstate is in charge of initializing self.__backend by invoking
// __create_backend().
loweredModule.define(
R"(
def __getstate__(self):
return self.__method_compile_spec, self.__processed_module
# The third parameter indicates whether __setstate__ must create
# the backend instance. It's hardcoded to True since the only
# case it can be false is when __setstate__ is called from
# outside the module (at module creation time), because
# __create_backend has been called already (also directly).
return self.__method_compile_spec, self.__processed_module, True
)",
loweredModuleResolver());

@@ -161,8 +189,13 @@
def __setstate__(self, state):
self.__method_compile_spec = state[0]
self.__processed_module = state[1]
self.__create_backend()
self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
# state[2] indicates whether to create the backend instance.
if state[2]:
self.__create_backend()
if self.__backend.is_available() :
self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
else:
raise Exception("Backend is not available.")
)",
loweredModuleResolver());

@@ -173,9 +206,12 @@
static const auto method_ct = CodeTemplate(R"(
def $method(self${,def_inputs}):
typed_inputs: List[Any] = [${fwd_inputs,}]
$unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
${refine,}
return $ret
if self.__backend.is_available() :
$unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
${refine,}
return $ret
else:
raise Exception("Backend is not available.")
)");

TemplateEnv method_te;
@@ -264,11 +300,24 @@ Module codegen_backend_module(
loweredModule.define(method_ct.format(method_te), loweredModuleResolver());
}

// Call __setstate__ to ensure that the returned Module is ready to
// run.
auto state = at::ivalue::Tuple::create(
method_compile_spec, loweredModule.attr("__processed_module"));
loweredModule.run_method("__setstate__", state);
// If backend is available, call __setstate__ to ensure that the returned
// Module is ready to run.
// Otherwise, emit a warning indicating that the resulting Module is not
// ready for execution until it is loaded on a device with the backend.
loweredModule.run_method("__create_backend");
if (loweredModule.run_method("__is_available").toBool()) {
auto state = at::ivalue::Tuple::create(
method_compile_spec,
loweredModule.attr("__processed_module"),
/*create_backend*/ false);
loweredModule.run_method("__setstate__", state);
} else {
TORCH_WARN(
"Backend [",
backend_name,
"] is not available. Execution of this Module is still possible by "
"saving and loading on a device where the backend is available.");
}
return loweredModule;
}
} // namespace detail
10 changes: 10 additions & 0 deletions torch/csrc/jit/backends/backend_detail.h
@@ -13,9 +13,19 @@ namespace detail {

constexpr static auto kBackendsNamespace = "__backends__";

c10::FunctionSchema TORCH_API getIsAvailableSchema();
c10::FunctionSchema TORCH_API getCompileSchema();
c10::FunctionSchema TORCH_API getExecuteSchema();

template <typename TBackendInterface>
std::function<void(Stack&)> getIsAvailableFunc() {
return [](Stack& stack) {
auto self = pop(stack).toCustomClass<TBackendInterface>();
auto ret = self->is_available();
push(stack, ret);
};
}

template <typename TBackendInterface>
std::function<void(Stack&)> getCompileFunc() {
return [](Stack& stack) {
3 changes: 3 additions & 0 deletions torch/csrc/jit/backends/backend_interface.h
@@ -11,6 +11,9 @@ class TORCH_API PyTorchBackendInterface : public torch::CustomClassHolder {
PyTorchBackendInterface();
virtual ~PyTorchBackendInterface();

// Returns true if the backend is available to process delegation calls.
virtual bool is_available() = 0;

// Compile the module contained in \p processed using the details provided in
// \p method_compile_spec for each module method that should be compiled for
// the backend. \p method_compile_spec should be of type Dict<string, Any>.
2 changes: 2 additions & 0 deletions torch/csrc/jit/backends/backend_resolver.cpp
@@ -47,6 +47,8 @@ struct LoweredModuleResolver : public Resolver {
return std::make_shared<BuiltinModule>("aten");
} else if (name == "__torch__") {
return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
} else if (name == "Exception") {
return std::make_shared<ExceptionValue>(name);
}

return nullptr;
