Commit 668574a: Add efficient zero tensors (pytorch#64837)
(forked from pytorch/pytorch; 1 parent: abda069)

Summary: Pull Request resolved: pytorch#64837
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D32144240
Pulled By: anjali411
fbshipit-source-id: d44096d882657c7f9270a16636900e0b73cefa40

Showing 22 changed files with 337 additions and 21 deletions.
The diff below shows the new ZeroTensor fallback file added by this commit (@@ -0,0 +1,104 @@, a new file):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

// TODO: add a note explaining the design decisions
// ZeroTensors are designed to be immutable, so we error out when an
// in-place operation is performed on a ZeroTensor.
void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
  const auto& arguments = op.schema().arguments();
  const auto num_arguments = arguments.size();
  const auto stack_start = stack->size() - num_arguments;

  c10::optional<bool> is_write;
  for (const auto i : c10::irange(num_arguments)) {
    const auto& alias_info = arguments[i].alias_info();
    if (alias_info != nullptr) {
      if (is_write.has_value()) {
        TORCH_CHECK(*is_write == alias_info->isWrite(),
          "Unsupported operator for ZeroTensorFallback: ", op.schema().name(),
          ". The ZeroTensor fallback doesn't work for operators with a mix of "
          "mutable and non-mutable inputs that alias with outputs; "
          "this must be implemented manually. "
          "If you got this error on a core op, please report a bug to PyTorch.");
      } else {
        is_write = alias_info->isWrite();
      }
    }
  }

  if (is_write.has_value() && !*is_write) {
    // We assume that view operators automatically handle the ZeroTensor bit
    // correctly by propagating the dispatch key in key_set.
    // This is not necessarily always right, so you should test these cases.
    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
    return;
  }

  for (const auto i : c10::irange(num_arguments)) {
    auto& ivalue = (*stack)[stack_start + i];
    if (!(ivalue.isTensor() || ivalue.isTensorList())) {
      continue;
    }
    const auto& argument = arguments[i];
    bool mut_arg = false;

    if (argument.alias_info()) {
      // Already checked by the is_write loop above.
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
      mut_arg = true;
    }

    if (ivalue.isTensor()) {
      auto tensor = std::move(ivalue).toTensor();
      if (tensor._is_zerotensor()) {
        TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
          "obtained using .clone() if you want a mutable tensor.");
        tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
      }
      (*stack)[stack_start + i] = std::move(tensor);
    } else if (ivalue.isTensorList()) {
      auto tensors = std::move(ivalue).toTensorList();
      for (const auto j : c10::irange(tensors.size())) {
        const Tensor& tensor = tensors[j];
        if (tensor._is_zerotensor()) {
          // TODO: assert requires_grad=False
          // _like should not propagate the zerotensor dispatch key
          TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
            "obtained using .clone() if you want a mutable tensor.");
          tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
        }
      }
      (*stack)[stack_start + i] = std::move(tensors);
    }
  }

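  // DispatchKeySet(FULL_AFTER, ZeroTensor) contains only the keys that come
  // after ZeroTensor in priority order, so this redispatch skips the
  // fallback and reaches the next kernel (e.g. CPU/CUDA) now that all zero
  // tensors on the stack have been materialized.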
  op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
}

TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>());
}
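
// Ops registered with makeFallthrough() skip the ZeroTensor key entirely:
// the dispatcher behaves as if the key were not set and calls the next
// kernel in line, so these ops see the original (zero) tensor unmodified.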
TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
  m.impl("zeros_like", torch::CppFunction::makeFallthrough());
  m.impl("mul.Scalar", torch::CppFunction::makeFallthrough());
  m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
  m.impl("copy_", torch::CppFunction::makeFallthrough());
  m.impl("clone", torch::CppFunction::makeFallthrough());
  // The functions in the list below have a specific registration in
  // native_functions.yaml and do not use the fallback.
  // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
  // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());

  TORCH_VIEW_FNS(m)
  TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}
} // namespace at
```
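The materialization step `at::zeros({}, options).expand(sizes)` is the key memory trick in the fallback. A small standalone illustration using plain ATen calls (independent of this PR) shows why no dense buffer is allocated:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A 0-dim zero expanded to a large shape is a broadcast view:
  // every stride is 0 and the storage holds a single element.
  at::Tensor z = at::zeros({}, at::TensorOptions().dtype(at::kFloat))
                     .expand({1024, 1024});

  std::cout << z.sizes() << "\n";             // [1024, 1024]
  std::cout << z.strides() << "\n";           // [0, 0]
  std::cout << z.storage().nbytes() << "\n";  // 4: one float, not 4 MiB
  return 0;
}
```

Because the expanded tensor is a read-only broadcast view, handing it to the redispatched kernel is safe only for non-mutable arguments, which is exactly why the fallback rejects zero tensors in mutable positions first.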