From 145bc5cd515ed4e4b58d9004fabab549d9e208e9 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 24 Mar 2021 13:47:50 -0700 Subject: [PATCH] Rename Math to CompositeImplicitAutograd (#54466) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54466 I had to very carefully audit all the use sites since there are a lot of other uses of the string Math; I did most of the conversion by grepping for all occurrences of Math and then doing a search replace. I also updated documentation for clarity. Signed-off-by: Edward Z. Yang Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D27253239 Pulled By: ezyang fbshipit-source-id: afb485d07ff39575742a4f0e1e205179b60bc953 --- BUILD.bazel | 2 +- aten/src/ATen/core/boxing/KernelFunction.cpp | 7 +- aten/src/ATen/core/boxing/KernelFunction.h | 46 ++++++++--- aten/src/ATen/core/dispatch/OperatorEntry.cpp | 38 ++++----- .../op_registration/op_registration_test.cpp | 36 ++++----- aten/src/ATen/native/README.md | 79 +++++++++++-------- aten/src/ATen/native/native_functions.yaml | 12 +-- c10/core/DispatchKey.cpp | 4 +- c10/core/DispatchKey.h | 2 +- c10/core/DispatchKeySet.cpp | 4 +- c10/core/DispatchKeySet.h | 4 +- test/test_dispatch.py | 73 ++++++++--------- tools/codegen/gen.py | 12 +-- tools/codegen/model.py | 26 +++--- torch/_python_dispatcher.py | 12 +-- torch/csrc/autograd/VariableTypeManual.cpp | 20 ++--- torch/csrc/utils/python_dispatch.cpp | 2 +- 17 files changed, 213 insertions(+), 166 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 9875805578b7d..00e726a784574 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -130,7 +130,7 @@ genrule( "aten/src/ATen/RegisterMkldnnCPU.cpp", "aten/src/ATen/RegisterQuantizedCPU.cpp", "aten/src/ATen/RegisterSparseCPU.cpp", - "aten/src/ATen/RegisterMath.cpp", + "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp", "aten/src/ATen/RegisterMeta.cpp", "aten/src/ATen/RegisterDefaultBackend.cpp", "aten/src/ATen/RegisterSchema.cpp", diff --git a/aten/src/ATen/core/boxing/KernelFunction.cpp b/aten/src/ATen/core/boxing/KernelFunction.cpp index 48b7c111cef98..c494280e52c68 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.cpp +++ b/aten/src/ATen/core/boxing/KernelFunction.cpp @@ -21,9 +21,10 @@ void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle& op, DispatchKeySet, Stack*) { TORCH_INTERNAL_ASSERT(0, - op.operator_name(), " has kernels registered to both Math and a backend mapped to AutogradOther. " - "This makes the backend kernel unreachable (see Note [Ambiguity in AutogradOther kernel]). " - "If it's intended to override Math kernel behavior, please open an issue to request a dedicated " + op.operator_name(), " has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther. " + "This makes the backend kernel unreachable; the dispatcher will always prefer the CompositeImplicitAutograd lowering " + "(see Note [Ambiguity in AutogradOther kernel]). " + "If you want to override CompositeImplicitAutograd, please open an issue to request a dedicated " "Autograd dispatch key for the backend.\n", "If you only want to run inference instead of training, add `at::AutoNonVariableTypeMode guard(true);` " "before model.forward(). 
Note this guard is only available in C++ but not Python at present.", diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h index d4c3677ea18fe..07e4cb2bb3325 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.h +++ b/aten/src/ATen/core/boxing/KernelFunction.h @@ -18,15 +18,43 @@ struct OperatorKernel; TORCH_API void fallthrough_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); // Note [Ambiguity in AutogradOther kernel] -// This kernel implements reporting an error message when there're kernels registered -// to both Math and a backend of AutogradOther, we don't know which kernel to pick: -// - if we pick Math kernel for AutogradOther, the kernel registered to backend will be -// silently ignored and never called. -// - if we skip using Math kernel for AutogradOther (it might pick Autograd kernel if available), -// it'll break all backends mapped to AutogradOther without a direct registration to backend. -// See c10/core/DispatchKeySet.cpp for a list of backends mapped to AutogradOther. -// Thus if backend extender indeed want to override Math kernel behavior, they should request -// a dedicated Autograd key for their backend to resolve the ambiguity. +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This error-reporting kernel is registered to the AutogradOther entry in the +// dispatch table when there is both a CompositeImplicitAutograd kernel and a +// backend kernel for ANY backend that maps to AutogradOther. To see why +// this is necessary in the AutogradOther case, it's helpful to first see +// why everything works out fine for a backend that has a reserved Autograd +// entry (see rule 2.2 in [Note] DispatchTable computation): +// +// CPU AutogradCPU +// reg? registers with... +// ------------------------------------------------- +// y Autograd registration takes precedence +// over CompositeImplicitAutograd. +// This is good, because the CPU specific backend +// implementation is more specialized and typically better; +// if we used the composite, we would bypass it. +// (NB: the Autograd key is guaranteed to exist because +// the autograd codegen requires it!) +// +// n CompositeImplicitAutograd takes precedence. +// This is also good, because the Autograd +// registration (if it exists) would try to redispatch +// to the (non-existent) CPU implementation; by +// using the composite, we ensure the operator +// actually works. +// +// As you can see, when we have a specific Autograd key (AutogradCPU), we can +// decide whether or not to use the CompositeImplicitAutograd kernel or the +// Autograd kernel based on whether or not the backend kernel exists. +// +// However, for AutogradOther (which is the catchall autograd kernel for +// everything that doesn't have a specific Autograd key), we can't do this +// trick because there isn't any unique backend to peek at to disambiguate; +// if there are some backends that have implementations they prefer Autograd, +// but unimplemented backends would prefer CompositeImplicitAutograd. Rather +// than arbitrarily pick one or the other, we just register a kernel that raises +// an error and let the user decide how to proceed. 
TORCH_API void ambiguous_autogradother_kernel(OperatorKernel*, const OperatorHandle&, DispatchKeySet, Stack*); // Note [named_not_supported_kernel] diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 36cd4bdcf6e30..a1360090d245b 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -108,8 +108,8 @@ std::list::iterator OperatorEntry::registerKernel( // Add the kernel to the kernels list, // possibly creating the list if this is the first kernel. - // Redirect catchAll registrations to Math. - auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::Math]; + // Redirect catchAll registrations to CompositeImplicitAutograd. + auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd]; if (k.size() > 0) { TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n", @@ -138,8 +138,8 @@ void OperatorEntry::deregisterKernel_( c10::optional dispatch_key, std::list::iterator kernel ) { - // Redirect catchAll deregistrations to Math. - DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::Math; + // Redirect catchAll deregistrations to CompositeImplicitAutograd. + DispatchKey dk = dispatch_key.has_value() ? *dispatch_key : DispatchKey::CompositeImplicitAutograd; auto found = kernels_.find(dk); TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator is ", toString(name_)); auto& k = found->second; @@ -186,13 +186,13 @@ std::pair OperatorEntry::computeDispatchTab // (2.1) Use kernel from DispatchKey::DefaultBackend if available. // This is used to register a kernel that works for all backend in inference. But it requires // separate registration for Autograd keys to support training. - // (2.2) Use kernel from DispatchKey::Math if available. - // For autograd keys, we only use kernel from Math when there's no direct registration - // to its corresponding backend key or DefaultBackend. See Note [DefaultBackend and Math]. + // (2.2) Use kernel from DispatchKey::CompositeImplicitAutograd if available. + // For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration + // to its corresponding backend key or DefaultBackend. See Note [DefaultBackend and CompositeImplicitAutograd]. // For AutogradOther, we eagerly return ambiguousAutogradOtherKernel_ if there's registration to any of // its backends and ask backend extender to request a decicated Autograd key for the backend. // See Note [Ambiguity in AutogradOther kernel] for more details. - // A DefaultBackend kernel prevents Math kernel being used for Autograd keys, but it doesn't + // A DefaultBackend kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't // cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available) // in this case. // (2.3) Use kernel from DispatchKey::Autograd if available @@ -201,11 +201,11 @@ std::pair OperatorEntry::computeDispatchTab // backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_] // (3) Use fallthrough kernel that are registered as fallback. 
// Alias Key Precedence: - // DefaultBackend > Math > Autograd - // Note [DefaultBackend and Math] - // When there're registrations to both DefaultBackend & Math & Autograd, from (2.2) we know DefaultBackend - // and Autograd kernels will be picked up and Math is overriden. - // This is fine and in practice DefaultBackend and Math shouldn't co-exist for an op. + // DefaultBackend > CompositeImplicitAutograd > Autograd + // Note [DefaultBackend and CompositeImplicitAutograd] + // When there're registrations to both DefaultBackend & CompositeImplicitAutograd & Autograd, from (2.2) we know DefaultBackend + // and Autograd kernels will be picked up and CompositeImplicitAutograd is overriden. + // This is fine and in practice DefaultBackend and CompositeImplicitAutograd shouldn't co-exist for an op. // TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA . // 1. Operator registration @@ -226,13 +226,13 @@ std::pair OperatorEntry::computeDispatchTab bool has_backend_kernel = hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend)); - // 2.2. Use Math kernel if available. For autograd keys, we only use kernel from Math + // 2.2. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use kernel from CompositeImplicitAutograd // when there's no direct registration to its corresponding backend key or DefaultBackend. // For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration // to any of its backends. // See Note [Undefined in dispatchTable_] for the special handling for Undefined. - if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::Math)) { - if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) { + if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) { + if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) { if (dispatch_key == DispatchKey::AutogradOther && hasKernelForAnyDispatchKey(c10::autogradother_backends)) { return {ambiguousAutogradOtherKernel_, "ambiguous autogradother"}; @@ -286,9 +286,9 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) { updateDispatchTableEntry_(dispatcher, k); } - // Registration to DefaultBackend and Math should be populated to Undefined. + // Registration to DefaultBackend and CompositeImplicitAutograd should be populated to Undefined. // We cannot do this above since Undefined cannot be represented in DispatchKeySet. - if (dispatch_key == DispatchKey::Math || dispatch_key == DispatchKey::DefaultBackend) { + if (dispatch_key == DispatchKey::CompositeImplicitAutograd || dispatch_key == DispatchKey::DefaultBackend) { updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined); } // Note [Refresh Runtime Autograd entries in dispatchTable_] @@ -319,7 +319,7 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) // the error message. // In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to // catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either DefaultBackend - // or Math alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, Math) + // or CompositeImplicitAutograd alias key so that we don't break the support. 
Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd) // should return true, it returns false because Undefined cannot be represented in a DispatchKeySet. for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { updateDispatchTable_(dispatcher, static_cast(iter)); diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 5e1889d48809f..bb296a3ecb2f7 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -520,7 +520,7 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_t auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - // catchAll now maps to Math which has higher precedence than Autograd + // catchAll now maps to CompositeImplicitAutograd which has higher precedence than Autograd called_nonautograd = called_autograd = false; op->typed().call(dummyTensor(DispatchKey::CPU, /*requires_grad=*/true)); EXPECT_TRUE(called_nonautograd); @@ -1306,7 +1306,7 @@ TEST(NewOperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchal called = false; auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPU), "hello "); - // CatchAll now maps to Math and has higher precedence than backend fallback. + // CatchAll now maps to CompositeImplicitAutograd and has higher precedence than backend fallback. EXPECT_TRUE(called); } @@ -1325,10 +1325,10 @@ TEST(NewOperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel EXPECT_FALSE(called_autograd); } -TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) { +TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradKernel) { bool math_called = false; auto m = MAKE_TORCH_LIBRARY(test); - m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; })); + m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; })); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); @@ -1370,17 +1370,17 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathKernel) { } } -TEST(NewOperatorRegistrationTest, dispatchWithMathAndAutogradKernel) { +TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndAutogradKernel) { bool math_called = false; bool autograd_called = false; auto m = MAKE_TORCH_LIBRARY(test); - m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; })); + m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; })); m.impl("fn", c10::DispatchKey::Autograd, [&](const Tensor& x) { autograd_called = true; return x; }); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); - // Math has higher precedence than Autograd + // CompositeImplicitAutograd has higher precedence than Autograd { math_called = autograd_called = false; callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true)); @@ -1396,17 +1396,17 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndAutogradKernel) { } } -TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) { +TEST(NewOperatorRegistrationTest, dispatchWithCompositeImplicitAutogradAndCatchAllKernel) { bool math_called = false; bool catchall_called = false; auto m = 
MAKE_TORCH_LIBRARY(test); - m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; })); + m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; })); m.impl("fn", [&](const Tensor& x) { catchall_called = true; return x; }); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); - // catchAll now maps to Math, which means we have two registrations to Math key. + // catchAll now maps to CompositeImplicitAutograd, which means we have two registrations to CompositeImplicitAutograd key. // The last registration is used. { catchall_called = math_called = false; @@ -1423,11 +1423,11 @@ TEST(NewOperatorRegistrationTest, dispatchWithMathAndCatchAllKernel) { } } -TEST(NewOperatorRegistrationTest, AutogradBackendOverridesMathKernel) { +TEST(NewOperatorRegistrationTest, AutogradBackendOverridesCompositeImplicitAutogradKernel) { bool math_called = false; bool autograd_called = false; auto m = MAKE_TORCH_LIBRARY(test); - m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; })); + m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; })); m.impl("fn", c10::DispatchKey::AutogradCPU, [&](const Tensor& x) { autograd_called = true; return x; }); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); @@ -1462,11 +1462,11 @@ TEST(NewOperatorRegistrationTest, AutogradBackendOverridesMathKernel) { } } -TEST(NewOperatorRegistrationTest, BackendOverridesMathKernel) { +TEST(NewOperatorRegistrationTest, BackendOverridesCompositeImplicitAutogradKernel) { bool math_called = false; bool backend_called = false; auto m = MAKE_TORCH_LIBRARY(test); - m.def("fn", torch::dispatch(c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; })); + m.def("fn", torch::dispatch(c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; })); m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; }); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); @@ -1550,12 +1550,12 @@ TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendKernel) { } } -TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendAndMathKernel) { +TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendAndCompositeImplicitAutogradKernel) { bool backend_called = false; bool math_called = false; auto m = MAKE_TORCH_LIBRARY(test); m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { backend_called = true; return x; })); - m.impl("fn", c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }); + m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }); auto op = Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); @@ -1735,7 +1735,7 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther bool sparsecpu_called, math_called = false; auto m = MAKE_TORCH_LIBRARY(test); m.def("fn", torch::dispatch(c10::DispatchKey::SparseCPU, [&](const Tensor& x) { sparsecpu_called = true; return x; })); - m.impl("fn", c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; }); + m.impl("fn", c10::DispatchKey::CompositeImplicitAutograd, [&](const Tensor& x) { math_called = true; return x; }); auto op 
= Dispatcher::singleton().findSchema({"test::fn", ""}); ASSERT_TRUE(op.has_value()); @@ -1748,7 +1748,7 @@ TEST(NewOperatorRegistrationTest, throwsWhenRegisterToBackendMapsToAutogradOther { expectThrows([&] { callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true)); - }, "test::fn has kernels registered to both Math and a backend mapped to AutogradOther."); + }, "test::fn has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther."); } } diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index f652250bbad40..3bde88521a240 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -264,37 +264,54 @@ dispatch: This specifies the actual name of the function you want to dispatch to, so you can dispatch to different functions depending on which backend the passed tensors -belong to. Technically, it is also possible to write `dispatch: func_name` -to unconditionally dispatch to a native function whose name is different than -the name in the public ATen API, but this is generally frowned upon (just name -them the same thing!) +belong to. If the dispatch table is omitted, we assume a default dispatch +table: + +``` +# overload is ignored +func: func.overload(...) -> ... +dispatch: + CompositeImplicitAutograd: func + +# overload is ignored, but out functions get suffixed with _out in their name +func: func.out_overload(...) -> ... +dispatch: + CompositeImplicitAutograd: func_out +``` If two backends have the same dispatch function, you can write `CPU, CUDA: func` to reuse the same function name in both cases. Available backend options can be found by searching `dispatch_keys` in [codegen](https://github.com/pytorch/pytorch/blob/master/tools/codegen/gen.py). -Among the supported backends, there're a few alias keys that maps to a set of backends: - - `DefaultBackend`: an alias that maps to all backends. Functions registered to - `DefaultBackend` should work for any backend for inference. (Note: - calling into a DispatchStub does NOT mean it works for any backend; +There are also two special "generic" backends: + + - `CompositeExplicitAutograd` (previously known as `DefaultBackend`): + implementations of kernels that work for all backends, but require an + explicit definition of backward function in `derivatives.yaml` to support autograd. + The most typical use of this key are for delegating functions; i.e., + functions that do a very small amount of work and then delegate to another + operator to do the actual heavy lifting. Under the hood, registering a + kernel to `CompositeExplicitAutograd` is equivalent to registering that + kernel to every backend (e.g., `CPU, CUDA`). Note: kernels which call + DispatchStub should NOT be registered as CompositeExplicitAutograd, as DispatchStub only works for `CPU, CUDA`) - - `Math`: an alias that maps to all backend and autograd backend keys. Functions - registered to `Math` key should be plain mathematical composition of other - `at::` functions and support training and inference for any backend. -`DefaultBackend` and `Math` keys act as defaults that can be overridden: for example, you can specify a custom -kernel for a particular backend using a backend-specific dispatch key, and use -`DefaultBackend` or `Math` to specify a generic kernel for the others. 
+ - `CompositeImplicitAutograd` (previously known as `Math`): implementations of + kernels that work for all backends, and also can implicitly support autograd, + because all of the operations it calls support autograd. Direct use of + this key should be rare: if you provide no dispatch table, we default to + registering your kernel as `CompositeImplicitAutograd`. Explicitly adding + this key to an existing dispatch table may be useful if you have specialized + CPU and CUDA implementations, but you might want to provide a fallback + lowering for external backends that may not have a specialized + implementation. -Note that like those registered to `Math`, kernels registered to `DefaultBackend` are -very often implemented as mathematical expressions built up from calls to other `at::` -functions. This is because in both cases, the kernel needs to delegate backend-specific -computation to the functions it calls. The difference between `DefaultBackend` and `Math` -is that a `Math` kernel also implicitly defines a derivative formula: to do this, it must -call only functions that themselves support autograd. +Functions registered to composite backends should work for any backend, if the +nested functions they call work for those backends. For example, suppose `my_op` can be implemented in the following way: + ``` at::Tensor my_op(const Tensor& self, const Tensor& other) { return self + 2 * other; @@ -302,26 +319,26 @@ at::Tensor my_op(const Tensor& self, const Tensor& other) { ``` If we already know inference kernels and derivative formulas for operators `+` and `*` in our system, -you can just register `my_op` to `Math` and both inference & autograd will just work. +you can just register `my_op` to `CompositeImplicitAutograd` and both inference & autograd will just work. Although it seems we only write down the inference formula here, PyTorch autograd system would correctly set up the backward for `my_op` using the chain formula and derivatives of `+` & `*` operators. In other words `d_out/d_self = 1; d_out/d_other = 2` can be derived automatically from the `my_op` inference kernel. Of course if we don't have derivative formula defined for either `+` or `*`, backward of `my_op` can no longer be derived automatically. -Whether to use `Math` or `DefaultBackend` for your kernel can be decided by the following steps: -1. If you can, always start with a `Math` kernel that's composable from existing operators. -2. If you don't want to use the derived gradient formula from `Math` kernel for autograd, either to - get better performance or better numerical stability, you should put the kernel in `DefaultBackend` +Whether to use implicit or explicit autograd for your kernel can be decided by the following steps: +1. If you can, always start with a `CompositeImplicitAutograd` kernel that's composable from existing operators. +2. If you don't want to use the derived gradient formula from `CompositeImplicitAutograd` kernel for autograd, either to + get better performance or better numerical stability, you should register the kernel with `CompositeExplicitAutograd` so that it's only used in inference. Later for autograd, depending on whether your autograd kernel works for all backends or not, you can put them in alias `Autograd` or specific keys like `AutogradCPU`. 3. If you prefer to write backend-specific kernels, use reserved dispatch keys for your backend instead, e.g. `CPU/AutogradCPU`. 
-**Important**: because a `Math` kernel is implicitly registered for ops with no `dispatch:` section, +**Important**: because a `CompositeImplicitAutograd` kernel is implicitly registered for ops with no `dispatch:` section, when you add a backend-specific kernel (and hence a `dispatch:` section) to one of these, you **must** also -add a `Math:` entry that names the old kernel implementation (it's named after the op, with _ +add a `CompositeImplicitAutograd:` entry that names the old kernel implementation (it's named after the op, with _ added if applicable), so that it's still available for other backends to use. If you implemented a native function in C++ and want to find out which dispatch keyword @@ -460,7 +477,7 @@ Here're steps to follow to decide the right dispatch keyword: - Yes: you're likely calling other `at::` ops in the implemetation. Go to step 2. 2. Think about training: does your kernel support autograd? [check autograd support](#will-your-function-be-automatically-differentiable) - - Yes: in other words, you're providing a `Math` kernel which supports both inference and autograd. + - Yes: in other words, you're providing a `CompositeImplicitAutograd` kernel which supports both inference and autograd. To use autograd support for training, simply skip adding a dispatch section and you're done. This will allow this op to be correctly registered for both inference and training. @@ -504,7 +521,7 @@ It shows for a certain operator, what the computed dispatch table looks like aft ``` dispatcher = PythonDispatcher() - dispatcher.register(["CPU", "XLA", "AutogradCPU", "Math"]) + dispatcher.register(["CPU", "XLA", "AutogradCPU", "CompositeImplicitAutograd"]) print(dispatcher.dispatchTable()) # Tells you exactly which kernel is used for certain backend. ``` @@ -512,8 +529,8 @@ It shows for a certain operator, what the computed dispatch table looks like aft Note that in native_functions.yaml you can mix using backend keywords and alias keywords above for one op: - direct registration to backend always has higher precendence than alias - - DO NOT provide multiple alias keywords to the same op: alias keywords have precedence `DefaultBackend > Math`, - e.g. adding both `Math` and `DefaultBackend` kernels for one op will completely ignore `Math` kernel for + - DO NOT provide multiple alias keywords to the same op: alias keywords have precedence `DefaultBackend > CompositeImplicitAutograd`, + e.g. adding both `CompositeImplicitAutograd` and `DefaultBackend` kernels for one op will completely ignore `CompositeImplicitAutograd` kernel for both inference and training. Thus this will trigger an error when native_functions.yaml is parsed. diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f69dc02eb1d6a..bfdc829b60297 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -398,7 +398,7 @@ variants: function, method dispatch: CPU, CUDA: addr - Math: math_addr + CompositeImplicitAutograd: math_addr - func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) variants: method @@ -408,7 +408,7 @@ - func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
dispatch: CPU, CUDA: addr_out - Math: math_addr_out + CompositeImplicitAutograd: math_addr_out - func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor variants: function @@ -1853,7 +1853,7 @@ - func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor) dispatch: CPU, CUDA: native_group_norm - Math: math_group_norm + CompositeImplicitAutograd: math_group_norm - func: native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int N, int C, int HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: @@ -2041,7 +2041,7 @@ dispatch: CPU: layer_norm_cpu CUDA: layer_norm_cuda - Math: math_native_layer_norm + CompositeImplicitAutograd: math_native_layer_norm - func: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: @@ -3219,7 +3219,7 @@ python_module: nn dispatch: CPU, CUDA: silu_backward - Math: math_silu_backward + CompositeImplicitAutograd: math_silu_backward - func: sigmoid(Tensor self) -> Tensor variants: function, method @@ -4054,7 +4054,7 @@ dispatch: CPU, CUDA: norm_out -# These four redispatch in their implementation, so OK to be Math +# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd - func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor variants: function, method diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 8509a96041b69..12ec486246542 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -109,8 +109,8 @@ const char* toString(DispatchKey t) { case DispatchKey::VmapMode: return "VmapMode"; - case DispatchKey::Math: - return "Math"; + case DispatchKey::CompositeImplicitAutograd: + return "CompositeImplicitAutograd"; case DispatchKey::DefaultBackend: return "DefaultBackend"; diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 808846a8f5503..c16c72bbaaca1 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -270,7 +270,7 @@ enum class DispatchKey : uint8_t { // See Note [Alias Dispatch Key : Autograd] Autograd, - Math, // registered at build/aten/src/ATen/RegisterMath.cpp + CompositeImplicitAutograd, // registered at build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp DefaultBackend, // registered at // build/aten/src/ATen/RegisterDefaultBackend.cpp diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index f91a9a6cd2d85..c70666c9371bb 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -22,7 +22,7 @@ bool isBackendDispatchKey(DispatchKey t) { } // math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset -// Alias key DispatchKey::Math maps to math_dispatch_keyset. +// Alias key DispatchKey::CompositeImplicitAutograd maps to math_dispatch_keyset. 
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset; DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { @@ -30,7 +30,7 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { switch (t) { case DispatchKey::Autograd: return autograd_dispatch_keyset; - case DispatchKey::Math: + case DispatchKey::CompositeImplicitAutograd: return math_dispatch_keyset; case DispatchKey::DefaultBackend: return backend_dispatch_keyset; diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index b1cdb95c7e35b..4d94b0ed835eb 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -189,7 +189,7 @@ C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); // autograd_dispatch_keyset should include all runtime autograd keys. // Alias key DispatchKey::Autograd maps to autograd_dispatch_keyset. -// NB: keys in this set also get associated with Math +// NB: keys in this set also get associated with CompositeImplicitAutograd constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, @@ -207,7 +207,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView = autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView); // backend dispatch keys that map to DispatchKey::AutogradOther -// NB: keys in this set also get associated with Math +// NB: keys in this set also get associated with CompositeImplicitAutograd constexpr DispatchKeySet autogradother_backends = DispatchKeySet({ DispatchKey::HIP, DispatchKey::FPGA, diff --git a/test/test_dispatch.py b/test/test_dispatch.py index 80b9f9adeac1b..8f8ecdb318a3d 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -245,7 +245,7 @@ def test_def(self): CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') def test_def_impl_schema_mismatch(self): @@ -285,7 +285,7 @@ def test_def_with_inference(self): CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') def test_def_only(self): @@ -317,7 +317,7 @@ def test_impl_only(self): CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') def test_computed_table(self): @@ -343,7 +343,7 @@ def test_computed_table(self): XLA: fn_xla :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] AutogradCPU: fn_autogradcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor 
_0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. @@ -375,7 +375,7 @@ def test_computed_table_with_cpu_math_autogradcpu_fallthrough(self): debug: registered at /dev/null:0 alias analysis kind: CONSERVATIVE CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. @@ -397,8 +397,8 @@ def test_computed_table_with_math(self): result = self.commute("foo", [ # m.def("foo(Tensor x) -> Tensor") lambda m: m.def_("foo(Tensor x) -> Tensor"), - # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "Math"), + # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -406,7 +406,7 @@ def test_computed_table_with_math(self): schema: test::foo(Tensor x) -> (Tensor) debug: registered at /dev/null:0 alias analysis kind: FROM_SCHEMA -Math[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. @@ -430,8 +430,8 @@ def test_computed_table_with_cpu_math(self): lambda m: m.def_("foo(Tensor x) -> Tensor"), # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), - # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), + # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -440,7 +440,7 @@ def test_computed_table_with_cpu_math(self): debug: registered at /dev/null:0 alias analysis kind: FROM_SCHEMA CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. @@ -484,7 +484,8 @@ def test_computed_table_with_autograd(self): AutogradXLA: impl_t_t [autograd kernel] ''') - # Now that catchAll maps to Math, registering to both catchAll and Math breaks commutativity. + # Now that catchAll maps to CompositeImplicitAutograd, registering to both + # catchAll and CompositeImplicitAutograd breaks commutativity. 
def test_computed_table_with_cpu_autograd_math(self): result = self.commute("foo", [ # m.def("foo(Tensor x) -> Tensor") @@ -493,8 +494,8 @@ def test_computed_table_with_cpu_autograd_math(self): lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), - # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), + # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"), ]) state, table = result.state, result.table self.assertExpectedInline(state, '''\ @@ -504,7 +505,7 @@ def test_computed_table_with_cpu_autograd_math(self): alias analysis kind: FROM_SCHEMA CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. @@ -525,8 +526,8 @@ def test_computed_table_with_ambiguous_autogradother(self): result = self.commute("foo", [ # m.def("foo(Tensor x) -> Tensor") lambda m: m.def_("foo(Tensor x) -> Tensor"), - # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), + # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"), # m.impl("foo", torch::kQuantizedCPU, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "QuantizedCPU", debug="fn_quantizedcpu"), ]) @@ -537,7 +538,7 @@ def test_computed_table_with_ambiguous_autogradother(self): debug: registered at /dev/null:0 alias analysis kind: FROM_SCHEMA QuantizedCPU: fn_quantizedcpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') # computed dispatch table is too big, so we only check on a few entries we're interested in. 
@@ -633,8 +634,8 @@ def test_computed_table_with_cpu_autograd_math_defaultbackend(self): lambda m: m.impl_t_t("foo", "CPU", debug="fn_cpu"), # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "Autograd", debug="fn_autograd"), - # m.impl("foo", torch::kMath, [](const Tensor & x) { return x }) - lambda m: m.impl_t_t("foo", "Math", debug="fn_math"), + # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x }) + lambda m: m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math"), # m.impl("foo", torch::kDefaultBackend, [](const Tensor & x) { return x }) lambda m: m.impl_t_t("foo", "DefaultBackend", debug="fn_defaultbackend"), ]) @@ -646,7 +647,7 @@ def test_computed_table_with_cpu_autograd_math_defaultbackend(self): alias analysis kind: FROM_SCHEMA CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] Autograd[alias]: fn_autograd :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] DefaultBackend[alias]: fn_defaultbackend :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''') @@ -749,15 +750,15 @@ def test_overwrite_math(self): '''\ name: test::foo schema: (none) -Math[alias]: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] -Math[alias] (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias]: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] +CompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ] ''' ) class TestPythonDispatcher(TestCase): def test_basic(self): dispatcher = PythonDispatcher() - dispatcher.register(["CPU", "XLA", "Math"]) + dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"]) self.assertExpectedInline( dispatcher.dispatchTable(), '''\ @@ -767,8 +768,8 @@ def test_basic(self): --------------------------- CPU fn_CPU [kernel] XLA fn_XLA [kernel] -QuantizedCPU fn_Math [math kernel] -AutogradOther fn_Math [math kernel] +QuantizedCPU fn_CompositeImplicitAutograd [math kernel] +AutogradOther fn_CompositeImplicitAutograd [math kernel] AutogradCPU fallthrough [backend fallback] AutogradXLA fallthrough [backend fallback] ''' @@ -776,7 +777,7 @@ def test_basic(self): def test_math_autogradcpu(self): dispatcher = PythonDispatcher() - dispatcher.register(["CPU", "XLA", "Math", "AutogradCPU"]) + dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd", "AutogradCPU"]) self.assertExpectedInline( dispatcher.dispatchTable(), '''\ @@ -786,8 +787,8 @@ def test_math_autogradcpu(self): --------------------------- CPU fn_CPU [kernel] XLA fn_XLA [kernel] -QuantizedCPU fn_Math [math kernel] -AutogradOther fn_Math [math kernel] +QuantizedCPU fn_CompositeImplicitAutograd [math kernel] +AutogradOther fn_CompositeImplicitAutograd [math kernel] AutogradCPU fn_AutogradCPU [kernel] AutogradXLA fallthrough [backend fallback] ''' @@ -802,7 +803,7 @@ def test_math_autogradcpu(self): CPU fn_CPU XLA fn_XLA AutogradCPU fn_AutogradCPU -Math[alias] fn_Math +CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd ''' ) @@ -841,7 +842,7 @@ def test_defaultbackend_autogradcpu(self): def test_autogradother(self): dispatcher = PythonDispatcher() - dispatcher.register(["CPU", "QuantizedCPU", "Math"]) + dispatcher.register(["CPU", "QuantizedCPU", "CompositeImplicitAutograd"]) self.assertExpectedInline( dispatcher.dispatchTable(), '''\ @@ -850,11 +851,11 @@ def 
test_autogradother(self): key kernel --------------------------- CPU fn_CPU [kernel] -XLA fn_Math [math kernel] +XLA fn_CompositeImplicitAutograd [math kernel] QuantizedCPU fn_QuantizedCPU [kernel] AutogradOther ambiguous_autogradother [ambiguous autogradother] AutogradCPU fallthrough [backend fallback] -AutogradXLA fn_Math [math kernel] +AutogradXLA fn_CompositeImplicitAutograd [math kernel] ''' ) @@ -867,7 +868,7 @@ def test_autogradother(self): --------------------------- CPU fn_CPU QuantizedCPU fn_QuantizedCPU -Math[alias] fn_Math +CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd ''' ) @@ -882,8 +883,8 @@ def test_defaultbackend_math(self): with self.assertRaisesRegex( RuntimeError, - r"Registration to both Math and DefaultBackend is not allowed"): - dispatcher.register(["DefaultBackend", "Math"]) + r"Registration to both CompositeImplicitAutograd and DefaultBackend is not allowed"): + dispatcher.register(["DefaultBackend", "CompositeImplicitAutograd"]) if __name__ == '__main__': diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 9c86a01aa9ac7..61b9f0e1ee94e 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -125,7 +125,7 @@ def static_dispatch_extra_headers(backend: Optional[DispatchKey]) -> str: return f""" #include #include -#include +#include """ def static_dispatch( @@ -147,7 +147,7 @@ def static_dispatch( # migrate math/default_backend ops to use structured delegate. return f'return at::{backend.lower()}::{name}({exprs_str});' - for dispatch_key in (backend, DispatchKey.DefaultBackend, DispatchKey.Math): + for dispatch_key in (backend, DispatchKey.DefaultBackend, DispatchKey.CompositeImplicitAutograd): if dispatch_key in f.dispatch: return f'return at::{dispatch_key.lower()}::{name}({exprs_str});' @@ -634,7 +634,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object: ('device_guard', f.device_guard), ('with_gil', False), ('deprecated', False), - ('has_math_kernel', DispatchKey.Math in f.dispatch), + ('has_math_kernel', DispatchKey.CompositeImplicitAutograd in f.dispatch), ]) @with_native_function @@ -646,7 +646,7 @@ def compute_registration_declarations(f: NativeFunction) -> str: comment_data : Dict[str, str] = { 'schema': f'aten::{f.func}', # TODO: What exactly is the semantics of the 'dispatch' field? 
- 'dispatch': str(f.dispatch.keys() != {DispatchKey.Math}), + 'dispatch': str(f.dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}), 'default': str(any(is_generic_dispatch_key(k) for k in f.dispatch)) } return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} @@ -862,7 +862,7 @@ def make_file_manager(install_dir: str) -> FileManager: DispatchKey.SparseCUDA, DispatchKey.QuantizedCPU, DispatchKey.QuantizedCUDA, - DispatchKey.Math, + DispatchKey.CompositeImplicitAutograd, DispatchKey.DefaultBackend, # Meta is a magic key: it is automatically generated for structured # kernels @@ -873,7 +873,7 @@ def make_file_manager(install_dir: str) -> FileManager: functions_keys = { DispatchKey.CPU, DispatchKey.CUDA, - DispatchKey.Math, + DispatchKey.CompositeImplicitAutograd, DispatchKey.DefaultBackend, } if options.backend_whitelist: diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 16a0afb585f92..9dbf7ebe27cec 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -102,7 +102,7 @@ class DispatchKey(Enum): TESTING_ONLY_GenericMode = auto() NumDispatchKeys = auto() Autograd = auto() - Math = auto() + CompositeImplicitAutograd = auto() DefaultBackend = auto() EndOfAliasKeys = DefaultBackend @@ -134,7 +134,7 @@ class UseC10Dispatcher(Enum): # Dispatch keys that "support all backends". These codegen slightly differently # then backend specific keys. def is_generic_dispatch_key(dk: DispatchKey) -> bool: - return dk in {DispatchKey.DefaultBackend, DispatchKey.Math} + return dk in {DispatchKey.DefaultBackend, DispatchKey.CompositeImplicitAutograd} # CUDA specific dispatch keys def is_cuda_dispatch_key(dk: DispatchKey) -> bool: @@ -205,7 +205,7 @@ class NativeFunction: # case, that is equivalent to having written: # # dispatch: - # Math: $operator_name + # CompositeImplicitAutograd: $operator_name dispatch: Dict[DispatchKey, str] # The location in the YAML file were this native function entry was @@ -249,7 +249,7 @@ def is_abstract(self) -> bool: # Structured functions MUST have a dispatch table return True else: - return self.dispatch.keys() != {DispatchKey.Math} + return self.dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} # NB: The benefit of defining a dataclass is that we automatically get # a constructor defined for all the fields we specify. No need @@ -337,20 +337,20 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': for k in ks.split(","): dispatch_key = DispatchKey.parse(k.strip()) dispatch[dispatch_key] = v - assert dispatch != {DispatchKey.Math: cpp.name(func)}, \ + assert dispatch != {DispatchKey.CompositeImplicitAutograd: cpp.name(func)}, \ "unnecessary dispatch table for this function; just delete the dispatch " \ "key entirely" - assert dispatch.keys() != {DispatchKey.Math}, \ - f"unexpected name for singleton Math dispatch entry: expected {cpp.name(func)} " \ - f"but got {dispatch[DispatchKey.Math]}. Rename your implementation to the expected " \ + assert dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}, \ + f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} " \ + f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. 
Rename your implementation to the expected " \ "name, then delete the dispatch table" elif not structured and structured_delegate is None: - dispatch[DispatchKey.Math] = cpp.name(func) + dispatch[DispatchKey.CompositeImplicitAutograd] = cpp.name(func) - assert not (DispatchKey.DefaultBackend in dispatch and DispatchKey.Math in dispatch), \ - "cannot specify both DefaultBackend and Math on a single kernel; each " \ + assert not (DispatchKey.DefaultBackend in dispatch and DispatchKey.CompositeImplicitAutograd in dispatch), \ + "cannot specify both DefaultBackend and CompositeImplicitAutograd on a single kernel; each " \ "strictly subsumes the other. If you wanted to provide an explicit autograd " \ - "implementation, specify DefaultBackend; otherwise specify Math only" + "implementation, specify DefaultBackend; otherwise specify CompositeImplicitAutograd only" e.pop('__line__') assert not e, f"leftover entries: {e}" @@ -454,7 +454,7 @@ def __post_init__(self) -> None: if self.structured: # For now, structured composite kernels are not supported (need some # design work to figure out how to make the composite case work) - assert self.out.dispatch.keys() != {DispatchKey.Math} + assert self.out.dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} assert self.functional.structured_delegate == self.out.func.name, \ f"{self.functional.func.name} delegates to {self.functional.structured_delegate} " \ diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py index 9b4a603385960..9ed1b8856ade4 100644 --- a/torch/_python_dispatcher.py +++ b/torch/_python_dispatcher.py @@ -26,16 +26,16 @@ Kernels registered to this key MUST work for inference for all backends. - Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther. Kernels registered to this key MUST work for autograd for all backends. -- Math: alias key Math = DefaultBackend + Autograd +- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = DefaultBackend + Autograd Kernels registered to this key MUST work for both inference + autograd for all backends. Note we only allow registrations to alias keys inside pytorch core library. E.g you shouldn't register -a Math or DefaultBackend kernel from torch-xla extension, instead you should upstream the kernel into +a CompositeImplicitAutograd or DefaultBackend kernel from torch-xla extension, instead you should upstream the kernel into pytorch/pytorch repo so that it's available for all backends and continuously tested even without the extension. Usage: dispatcher = PythonDispatcher() - dispatcher.register(["CPU", "XLA", "Math"]) + dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"]) print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend. # For more debugging information # print(dispatcher.keys()) @@ -57,7 +57,7 @@ class PythonDispatcher: alias_keys = [ "DefaultBackend", "Autograd", - "Math", + "CompositeImplicitAutograd", ] supported_keys = runtime_keys + alias_keys @@ -85,8 +85,8 @@ def register(self, dispatchKeys): if len(set(dispatchKeys)) != len(dispatchKeys): raise RuntimeError(f"Overriden is not allowed but found duplicates in {dispatchKeys}.") # We currently forbid this in codegen instead of C++ dispatcher. 
- if 'Math' in dispatchKeys and 'DefaultBackend' in dispatchKeys: - raise RuntimeError("Registration to both Math and DefaultBackend is not allowed.") + if 'CompositeImplicitAutograd' in dispatchKeys and 'DefaultBackend' in dispatchKeys: + raise RuntimeError("Registration to both CompositeImplicitAutograd and DefaultBackend is not allowed.") for key in dispatchKeys: if key not in self.supported_keys: raise RuntimeError(f"{key} is not supported, please select a dispatch key in {self.supported_keys}.") diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index ea728e1c4e0c5..90e055f6dc195 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -371,16 +371,16 @@ Tensor & detach_(Tensor & self) { } // Ops in the following registration list are registered as -// (1) Math kernels +// (1) CompositeImplicitAutograd kernels // (2) Autograd kernels // (3) DefaultBackend kernels and additionally Autograd kernels // The reason for (3) is that ops that also use dispatch (e.g. register CPU/CUDA/QuantizedCPU -// kernels) will skip picking up Math kernels for Autograd, so we register them to both +// kernels) will skip picking up CompositeImplicitAutograd kernels for Autograd, so we register them to both // DefaultBackend and Autograd instead. See // https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword // for more details. // Invariant: -// - Ops registered to Math or DefaultBackend below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py. +// - Ops registered to CompositeImplicitAutograd or DefaultBackend below must match `MANUAL_BACKEND` set in tools/autograd/gen_variable_type.py. // and they have manual_kernel_registration=True in native_functions.yaml. 
// - Ops registered to DispatchKey::Autograd below must be included in `MANUAL_AUTOGRAD` in tools/autograd/gen_variable_type.py @@ -398,13 +398,13 @@ TORCH_LIBRARY_IMPL(aten, DefaultBackend, m) { m.impl("requires_grad_", torch::dispatch(DispatchKey::DefaultBackend, TORCH_FN(VariableType::requires_grad_))); } -TORCH_LIBRARY_IMPL(aten, Math, m) { - m.impl("set_data", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::set_data))); - m.impl("data", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::data))); - m.impl("is_leaf", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::is_leaf))); - m.impl("output_nr", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::output_nr))); - m.impl("_version", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::_version))); - m.impl("retain_grad", torch::dispatch(DispatchKey::Math, TORCH_FN(VariableType::retain_grad))); +TORCH_LIBRARY_IMPL(aten, CompositeImplicitAutograd, m) { + m.impl("set_data", torch::dispatch(DispatchKey::CompositeImplicitAutograd, TORCH_FN(VariableType::set_data))); + m.impl("data", torch::dispatch(DispatchKey::CompositeImplicitAutograd, TORCH_FN(VariableType::data))); + m.impl("is_leaf", torch::dispatch(DispatchKey::CompositeImplicitAutograd, TORCH_FN(VariableType::is_leaf))); + m.impl("output_nr", torch::dispatch(DispatchKey::CompositeImplicitAutograd, TORCH_FN(VariableType::output_nr))); + m.impl("_version", torch::dispatch(DispatchKey::CompositeImplicitAutograd, TORCH_FN(VariableType::_version))); + m.impl("retain_grad", torch::dispatch(DispatchKey::CompositeImplicitAutograd, TORCH_FN(VariableType::retain_grad))); } } // namespace diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 9c228795a3d98..2ff386d89a81b 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -32,7 +32,7 @@ c10::optional parseDispatchKey(const std::string& k) { {"CUDA", c10::DispatchKey::CUDA}, {"XLA", c10::DispatchKey::XLA}, {"QuantizedCPU", c10::DispatchKey::QuantizedCPU}, - {"Math", c10::DispatchKey::Math}, + {"CompositeImplicitAutograd", c10::DispatchKey::CompositeImplicitAutograd}, {"Autograd", c10::DispatchKey::Autograd}, {"DefaultBackend", c10::DispatchKey::DefaultBackend}, {"AutogradCPU", c10::DispatchKey::AutogradCPU},