From c5cd993add01f5114851b6b868824d59752642ba Mon Sep 17 00:00:00 2001 From: Raziel Alvarez Guevara Date: Wed, 10 Mar 2021 00:21:34 -0800 Subject: [PATCH] Adds a bool is_available() method to the backend contract (#53068) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53068 Adds a ```bool is_available()``` method to the backend contract: it returns ```true``` if ```compile()``` and ```execute()``` can be called; ```false``` otherwise. It is used to implement the following changes in the ```LoweredModule```: * ```compile()``` in ```__setstate__``` will run if ```is_available()```, else ```__setstate__``` throws an exception (“Backend not available.”). * ```compile()``` at ```LoweredModule``` creation will run if ```is_available()```, else a WARNING will be thrown. * ```execute()``` will only be executed if ```is_available()``` returns true; else throws an exception (“Backend not available.”). The goal of these changes is to ensure we have a well defined behaviour for the different combinations of backend availability on-host and on-target. More specifically, backends may have different capabilities to compile and/or execute the Module, depending whether this happens on-host (i.e. where the program is being written) or on-target (where the program is being executed). First of all, we know that "preprocess" always takes place, and that only happens on-host at creation time. So, we can assume that any compilation is needed/possible on-host then all of it could be pushed here. Overall, we want to ensure the following: **On host** | compile | execute | Outcome | | -- | -- | -- | | No | No | On module creation, LoweredModule is generated, with a warning (since compilation and execution can still take place on-target). On module load, throws an exception (since execution is not possible). | | No | Yes | This configuration should not be possible. 
This assumes the full compiler is not available, even if some work was done in preprocess the program cannot be finalized for execution. | | Yes | No | In this case, the expectation would be for is_available() to return false, and compilation logic to move into preprocess. | | Yes | Yes | All good. This is the only case that is_available() should return true. | **On target** | compile | execute | Outcome | | -- | -- | -- | | No | No | Loading the LoweredModule throws an exception. Since execution is not possible. | | No | Yes | Basically this is another instance of Yes/Yes: compilation per se may not be possible on device, which means compile() can be called without issue but it is a no-op, and thus is_available should return true. Consequently, loading the LoweredModule: Succeeds, if the preprocessed module is ready for execution. Fails with exception otherwise. | | Yes | No | This configuration should not be possible. Just putting here for completeness. | | Yes | Yes | All good. This, along with No/Yes case (because compilation is assumed to have happened on-host, so it's just another instance of Yes/Yes), are the cases where is_available() should return true. | **Refactoring existing code** This change also updates other backends (Glow) code, to implement the is_available() method to have the same behaviour as before this change (i.e. always available). This should not cause backward incompatibilities with already saved models since we're adding a new method to the PyTorchBackendInterface. Models saved with the old interface that didn't have is_available() will still find the other 2 methods in the bound object (i.e. compile and execute), and the saved LoweredModule logic will be the old one. **Future** We plan to use is_available() to implement support for fallback to the PyTorch interpreter. ghstack-source-id: 123498571 Test Plan: Added C++ (test_backend.cpp) and Python (test_backends.py) tests to validate the exceptions. 
Reviewed By: jackm321, spaugh, iseeyuan Differential Revision: D26615833 fbshipit-source-id: 562e8b11db25784348b5f86bbc4179aedf15e0d3 --- test/cpp/jit/test_backend.cpp | 33 +++++++++ test/cpp/jit/test_backend_compiler_lib.cpp | 4 ++ test/cpp/jit/test_backend_lib.cpp | 11 ++- test/custom_backend/custom_backend.h | 4 ++ test/jit/test_backends.py | 53 ++++++++++++++- torch/csrc/jit/backends/backend.h | 4 ++ torch/csrc/jit/backends/backend_detail.cpp | 71 +++++++++++++++++--- torch/csrc/jit/backends/backend_detail.h | 10 +++ torch/csrc/jit/backends/backend_interface.h | 3 + torch/csrc/jit/backends/backend_resolver.cpp | 2 + 10 files changed, 181 insertions(+), 14 deletions(-) diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index a0f8cef1f9e31..cca8294c407e6 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -79,6 +79,39 @@ TEST(BackendTest, ToBackend) { AT_ASSERT(res[1].toTensor().equal(ref[1].toTensor())); } +TEST(BackendTest, ToBackendNotAvailable) { + Module m("m"); + m.define(R"( + def forward(self, x, h): + return self.accum(x, h), self.sub_accum(x, h) + + def accum(self, x, h): + return x + h + + def sub_accum(self, x, h): + return x - h + )"); + + std::vector inputs; + inputs.emplace_back(2.0 * torch::ones({})); + inputs.emplace_back(1.0 * torch::ones({})); + auto ref = m.forward(inputs).toTuple()->elements(); + + c10::Dict compile_spec(StringType::get(), AnyType::get()); + c10::Dict fake_dict(StringType::get(), AnyType::get()); + fake_dict.insert("", ""); + compile_spec.insert("forward", fake_dict); + auto any_dict_ty = DictType::create(StringType::get(), AnyType::get()); + // Produce lowered module (backend not available). + // Exception is not thrown at this point. + auto lm = torch::jit::detail::codegen_backend_module( + "test_backend_unavailable", m, compile_spec, any_dict_ty); + // Validate exception is thrown when trying to execute and + // the backend is not available. 
+ ASSERT_THROWS_WITH_MESSAGE( + lm.forward(inputs).toTuple()->elements(), "Backend is not available."); +} + TEST(BackendTest, TestCompiler) { Module m("m"); m.define(R"( diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 93e639b5d5d86..caf4e3f0494d6 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -49,6 +49,10 @@ class BackendWithCompiler : public PyTorchBackendInterface { explicit BackendWithCompiler() {} virtual ~BackendWithCompiler() = default; + bool is_available() override { + return true; + } + // Since the actual compilation is done AOT, c10::impl::GenericDict compile( c10::IValue processed, diff --git a/test/cpp/jit/test_backend_lib.cpp b/test/cpp/jit/test_backend_lib.cpp index e401c72ff0595..5c10308f40cdf 100644 --- a/test/cpp/jit/test_backend_lib.cpp +++ b/test/cpp/jit/test_backend_lib.cpp @@ -6,12 +6,17 @@ namespace jit { // necessary to test that the JIT backend registration endpoints and // code generation are working correctly. It is not intended to // produce numerically correct results. +template class TestBackend : public PyTorchBackendInterface { public: // Constructor. 
explicit TestBackend() {} virtual ~TestBackend() = default; + bool is_available() override { + return isAvailable; + } + c10::impl::GenericDict compile( c10::IValue processed, c10::impl::GenericDict method_compile_spec) override { @@ -68,7 +73,11 @@ c10::IValue preprocess( return mod._ivalue(); } -static auto cls = torch::jit::backend("test_backend", preprocess); +static auto cls_available = + torch::jit::backend>("test_backend", preprocess); +static auto cls_unavailable = torch::jit::backend>( + "test_backend_unavailable", + preprocess); } // namespace } // namespace jit diff --git a/test/custom_backend/custom_backend.h b/test/custom_backend/custom_backend.h index 125339b98c94a..b1f8ca13609dc 100644 --- a/test/custom_backend/custom_backend.h +++ b/test/custom_backend/custom_backend.h @@ -12,6 +12,10 @@ class CustomBackend : public torch::jit::PyTorchBackendInterface { explicit CustomBackend() {} virtual ~CustomBackend() = default; + bool is_available() override { + return true; + } + c10::impl::GenericDict compile( c10::IValue processed, c10::impl::GenericDict method_compile_spec) override { diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py index 9f61ec77a1f60..2e1b786f2b464 100644 --- a/test/jit/test_backends.py +++ b/test/jit/test_backends.py @@ -1,4 +1,5 @@ from torch.testing._internal.jit_utils import JitTestCase +import io import os import sys import unittest @@ -61,6 +62,9 @@ def sub_accum(self, x, h): return x - h +# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. 
+@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, + "Non-portable load_library call used in test") class JitBackendTestCase(JitTestCase): """ A common base class for JIT backend tests that contains common utility @@ -69,8 +73,6 @@ class JitBackendTestCase(JitTestCase): def setUp(self): super().setUp() - if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE: - raise unittest.SkipTest("non-portable load_library call used in test") torch_root = Path(__file__).resolve().parent.parent.parent p = torch_root / 'build' / 'lib' / 'libjitbackend_test.so' torch.ops.load_library(str(p)) @@ -168,6 +170,46 @@ def test_save_load(self): self.test_execution() +class BasicModuleUnavailableTest(JitBackendTestCase): + """ + Tests for BasicModule with a backend that is not available. + Fundamentally: + * _jit_to_backend is successful. + * Execution fails with an exception. + * Saving is successful. + * Loading fails with an exception. + """ + + def setUp(self): + super().setUp() + # Create Python, JIT and backend versions of BasicModule. + self.module = BasicModule() + self.scripted_module = torch.jit.script(BasicModule()) + self.lowered_module = torch._C._jit_to_backend( + "test_backend_unavailable", + self.scripted_module, + {"forward": {"": ""}}, + ) + + def test_execution(self): + # Test execution with backend fails because the backend that is not available. + input = torch.randn(5) + + # Test exception is thrown. + with self.assertRaisesRegex(Exception, r"Backend is not available."): + backend_method = self.lowered_module.__getattr__("forward") + backend_output = backend_method(*(input, input)) + + @skipIfRocm + def test_save_load(self): + # Test that saving the lowered module is OK but loading fails because the backend is not available. 
+ buffer = io.BytesIO() + torch.jit.save(self.lowered_module, buffer) + buffer.seek(0) + with self.assertRaisesRegex(Exception, r"Backend is not available."): + imported = torch.jit.load(buffer) + + class NestedModuleTest(JitBackendTestCase): """ Tests for NestedModule that check that a module lowered to a backend can be used @@ -376,6 +418,9 @@ def test_errors(self): to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]) +# This is needed for IS_WINDOWS or IS_MACOS to skip the tests. +@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, + "Non-portable load_library call used in test") class TestBackends(JitTestCase): """ This class wraps and invokes all subclasses of JitBackendTestCase so that each one @@ -385,6 +430,7 @@ class TestBackends(JitTestCase): def __init__(self, name): super().__init__(name) self.basic_module_test = BasicModuleTest(name) + self.basic_module_unavailable_test = BasicModuleUnavailableTest(name) self.nested_module_test = NestedModuleTest(name) self.selective_lowering_test = SelectiveLoweringTest(name) @@ -392,18 +438,21 @@ def setUp(self): super().setUp() if not TEST_WITH_ROCM: self.basic_module_test.setUp() + self.basic_module_unavailable_test.setUp() self.nested_module_test.setUp() self.selective_lowering_test.setUp() @skipIfRocm def test_execution(self): self.basic_module_test.test_execution() + self.basic_module_unavailable_test.test_execution() self.nested_module_test.test_execution() self.selective_lowering_test.test_execution() @skipIfRocm def test_save_load(self): self.basic_module_test.test_save_load() + self.basic_module_unavailable_test.test_save_load() self.nested_module_test.test_save_load() self.selective_lowering_test.test_save_load() diff --git a/torch/csrc/jit/backends/backend.h b/torch/csrc/jit/backends/backend.h index b423f663f7ed9..2d648d81bf298 100644 --- a/torch/csrc/jit/backends/backend.h +++ b/torch/csrc/jit/backends/backend.h @@ -28,6 +28,10 @@ class 
backend { static auto cls = torch::class_(detail::kBackendsNamespace, name) .def(torch::init<>()) + ._def_unboxed( + "is_available", + detail::getIsAvailableFunc(), + detail::getIsAvailableSchema()) ._def_unboxed( "compile", detail::getCompileFunc(), diff --git a/torch/csrc/jit/backends/backend_detail.cpp b/torch/csrc/jit/backends/backend_detail.cpp index 512ed5ba4493f..5dbd3dcdec053 100644 --- a/torch/csrc/jit/backends/backend_detail.cpp +++ b/torch/csrc/jit/backends/backend_detail.cpp @@ -10,6 +10,17 @@ namespace torch { namespace jit { namespace detail { +c10::FunctionSchema getIsAvailableSchema() { + c10::Argument self("self", c10::AnyType::get()); + c10::Argument available("available", c10::BoolType::get()); + c10::FunctionSchema preprocessor_schema( + "is_available", + /*overload_name=*/"", + /*arguments=*/{self}, + /*returns=*/{available}); + return preprocessor_schema; +} + c10::FunctionSchema getCompileSchema() { c10::Argument self("self", c10::AnyType::get()); c10::Argument mod("processed", c10::AnyType::get()); @@ -147,12 +158,29 @@ Module codegen_backend_module( loweredModule.define( create_backend_ct.format(create_backend_te), loweredModuleResolver()); + // Helper function to expose backend.is_available() to Module generation code. + // Assumes self.__backend exists (i.e. __create_backend() has already been + // invoked). + loweredModule.define( + R"( + def __is_available(self): + return self.__backend.is_available() + )", + loweredModuleResolver()); + // getstate and setstate are for serialization/deserialization of // the LoweredModule. + // setstate is in charge of initializing self.__backend by invoking + // __create_backend(). loweredModule.define( R"( def __getstate__(self): - return self.__method_compile_spec, self.__processed_module + # The third parameter indicates whether __setstate__ must create + # the backend instance. 
It's hardcoded to True since the only + # case it can be false is when __setstate__ is called from + # outside the module (at module creation time), because + # __create_backed has been called already (also directly). + return self.__method_compile_spec, self.__processed_module, True )", loweredModuleResolver()); @@ -161,8 +189,13 @@ Module codegen_backend_module( def __setstate__(self, state): self.__method_compile_spec = state[0] self.__processed_module = state[1] - self.__create_backend() - self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec) + # state[2] indicates whether to create the backend instance. + if state[2]: + self.__create_backend() + if self.__backend.is_available() : + self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec) + else: + raise Exception("Backend is not available.") )", loweredModuleResolver()); @@ -173,9 +206,12 @@ Module codegen_backend_module( static const auto method_ct = CodeTemplate(R"( def $method(self${,def_inputs}): typed_inputs: List[Any] = [${fwd_inputs,}] - $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs) - ${refine,} - return $ret + if self.__backend.is_available() : + $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs) + ${refine,} + return $ret + else: + raise Exception("Backend is not available.") )"); TemplateEnv method_te; @@ -264,11 +300,24 @@ Module codegen_backend_module( loweredModule.define(method_ct.format(method_te), loweredModuleResolver()); } - // Call __setstate__ to ensure that the returned Module is ready to - // run. - auto state = at::ivalue::Tuple::create( - method_compile_spec, loweredModule.attr("__processed_module")); - loweredModule.run_method("__setstate__", state); + // If backend is available, call __setstate__ to ensure that the returned + // Module is ready to run. 
+ // Otherwise throw a warning indicating that the resulting Module is not + // ready for execution until it is loaded to a device with the backend. + loweredModule.run_method("__create_backend"); + if (loweredModule.run_method("__is_available").toBool()) { + auto state = at::ivalue::Tuple::create( + method_compile_spec, + loweredModule.attr("__processed_module"), + /*create_backend*/ false); + loweredModule.run_method("__setstate__", state); + } else { + TORCH_WARN( + "Backend [", + backend_name, + "] is not available. Execution of this Module is still possible by " + "saving and loading on a device where the backend is available."); + } return loweredModule; } } // namespace detail diff --git a/torch/csrc/jit/backends/backend_detail.h b/torch/csrc/jit/backends/backend_detail.h index 11a701245693f..1d75378017566 100644 --- a/torch/csrc/jit/backends/backend_detail.h +++ b/torch/csrc/jit/backends/backend_detail.h @@ -13,9 +13,19 @@ namespace detail { constexpr static auto kBackendsNamespace = "__backends__"; +c10::FunctionSchema TORCH_API getIsAvailableSchema(); c10::FunctionSchema TORCH_API getCompileSchema(); c10::FunctionSchema TORCH_API getExecuteSchema(); +template +std::function getIsAvailableFunc() { + return [](Stack& stack) { + auto self = pop(stack).toCustomClass(); + auto ret = self->is_available(); + push(stack, ret); + }; +} + template std::function getCompileFunc() { return [](Stack& stack) { diff --git a/torch/csrc/jit/backends/backend_interface.h b/torch/csrc/jit/backends/backend_interface.h index caa19052bcc9e..e6d5eac3fd2c7 100644 --- a/torch/csrc/jit/backends/backend_interface.h +++ b/torch/csrc/jit/backends/backend_interface.h @@ -11,6 +11,9 @@ class TORCH_API PyTorchBackendInterface : public torch::CustomClassHolder { PyTorchBackendInterface(); virtual ~PyTorchBackendInterface(); + // Returns true if the backend is available to process delegation calls.
+ virtual bool is_available() = 0; + // Compile the module contained in \p processed using the details provided in // \p method_compile_spec for each module method that should be compiled for // the backend. \p method_compile_spec should be of type Dict. diff --git a/torch/csrc/jit/backends/backend_resolver.cpp b/torch/csrc/jit/backends/backend_resolver.cpp index a3a8c63e94c8a..01b1face00cc5 100644 --- a/torch/csrc/jit/backends/backend_resolver.cpp +++ b/torch/csrc/jit/backends/backend_resolver.cpp @@ -47,6 +47,8 @@ struct LoweredModuleResolver : public Resolver { return std::make_shared("aten"); } else if (name == "__torch__") { return std::make_shared(c10::QualifiedName(name)); + } else if (name == "Exception") { + return std::make_shared(name); } return nullptr;