Add efficient zero tensors (pytorch#64837)
Summary: Pull Request resolved: pytorch#64837

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D32144240

Pulled By: anjali411

fbshipit-source-id: d44096d882657c7f9270a16636900e0b73cefa40
anjali411 authored and facebook-github-bot committed Dec 2, 2021
1 parent abda069 commit 668574a
Showing 22 changed files with 337 additions and 21 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -136,6 +136,7 @@ genrule(
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterMeta.cpp",
104 changes: 104 additions & 0 deletions aten/src/ATen/ZeroTensorFallback.cpp
@@ -0,0 +1,104 @@
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#include <ATen/native/MathBitFallThroughLists.h>

namespace at {

// TODO: add a note explaining the design decisions
// ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors
void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;

c10::optional<bool> is_write;
for (const auto i : c10::irange(num_arguments)) {
const auto& alias_info = arguments[i].alias_info();
if (alias_info != nullptr) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(),
"ZeroTensor fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
}

if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle the ZeroTensor bit
// correctly by propagating the dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
return;
}

for (const auto i : c10::irange(num_arguments)) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;

if (argument.alias_info()) {
// Was already tested by is_write loop above
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}

if (ivalue.isTensor()) {
auto tensor = std::move(ivalue).toTensor();
if (tensor._is_zerotensor()) {
TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
"obtained using .clone() if you want a mutable tensor.");
tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
}
(*stack)[stack_start + i] = std::move(tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
for(const auto j : c10::irange(tensors.size())) {
const Tensor& tensor = tensors[j];
if (tensor._is_zerotensor()) {
// TODO: assert requires_grad=False
//_like should not propagate zerotensor dispatch key
TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
"obtained using .clone() if you want a mutable tensor.");
tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
}
}
(*stack)[stack_start + i] = std::move(tensors);
}
}

op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
}


TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>());
}

TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
m.impl("zeros_like", torch::CppFunction::makeFallthrough());
m.impl("mul.Scalar", torch::CppFunction::makeFallthrough());
m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
m.impl("copy_", torch::CppFunction::makeFallthrough());
m.impl("clone", torch::CppFunction::makeFallthrough());
// The functions in the list below have a specific registration in native_functions.yaml and
// do not use the fallback.
// m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
// m.impl("add.Tensor", torch::CppFunction::makeFallthrough());

TORCH_VIEW_FNS(m)
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
}
} // namespace at
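
A minimal caller-side sketch of what the fallback above provides, assuming a libtorch build that contains this change (the snippet is illustrative and not part of the diff): ZeroTensor inputs are materialized before the op is redispatched, and in-place mutation of a ZeroTensor is rejected.

```cpp
// Caller-side sketch of the boxed ZeroTensor fallback (assumed build).
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Factory added in this PR: only the ZeroTensor dispatch key is set and
  // no storage is allocated. Device is passed explicitly for clarity.
  at::Tensor z = at::_efficientzerotensor({2, 3}, at::TensorOptions().device(at::kCPU));
  std::cout << z._is_zerotensor() << "\n";   // 1

  // sub has no dedicated ZeroTensor kernel, so it hits zeroTensorFallback,
  // which swaps z for a materialized (expanded) zeros tensor and redispatches.
  at::Tensor x = at::randn({2, 3});
  at::Tensor y = at::sub(x, z);              // same values as x

  // In-place ops on a ZeroTensor are rejected by the mut_arg check above:
  // z.add_(1.0);  // error: "ZeroTensors are immutable. ..."
  return 0;
}
```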
8 changes: 8 additions & 0 deletions aten/src/ATen/core/TensorBase.h
@@ -304,6 +304,14 @@ class TORCH_API TensorBase {
return impl_->storage().is_alias_of(other.storage());
}

inline bool _is_zerotensor() const {
return impl_->_is_zerotensor();
}

inline void _set_zero(bool zero) const {
impl_->_set_zero(zero);
}

inline bool is_conj() const {
return impl_->is_conj();
}
42 changes: 41 additions & 1 deletion aten/src/ATen/native/BinaryOps.cpp
@@ -7,7 +7,8 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/ExpandUtils.h>
#include <ATen/RedispatchFunctions.h>
#include <torch/library.h>

namespace at {
@@ -625,6 +626,45 @@ Tensor& mul_(Tensor& self, const Scalar& other) {
return at::mul_out(self, wrapped_scalar_tensor(other), self); // redispatch!
}

Device correct_out_device(const Tensor& self, const Tensor& other) {
if (self.device() == at::kCPU){
return other.device();
} else {
return self.device();
}
}

Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
auto out_device = correct_out_device(self, other);
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
auto meta_out = at::redispatch::mul(c10::DispatchKeySet(at::DispatchKey::Meta), self.to(device_), other.to(device_));
return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
}

Tensor add_zerotensor(const Tensor& self, const Tensor& other, const Scalar& alpha) {
auto out_device = correct_out_device(self, other);
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
auto meta_out = at::redispatch::add(c10::DispatchKeySet(at::DispatchKey::Meta), self.to(device_), other.to(device_));

auto get_out_like = [&] (const Tensor& tensor)
{
auto sizes = meta_out.sizes();
return at::_to_copy(tensor.expand(sizes), meta_out.options().device(out_device));
};

if (self._is_zerotensor()) {
if (other._is_zerotensor()) {
return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
}
auto res = get_out_like(other);
return alpha.equal(1) ? res : res.mul(alpha);
} else {
return get_out_like(self);
}
}

// multiply, alias for mul
Tensor& multiply_out(const Tensor& self, const Tensor& other, Tensor& result) {
return at::mul_out(result, self, other);
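
A sketch of the algebra behind the two fast paths above, under the same build assumption (tensor names and shapes are made up): mul with a ZeroTensor operand returns a fresh efficient zero tensor with the broadcast shape and promoted dtype computed via the Meta redispatch, while add reduces to an expanded copy of the non-zero operand, scaled by alpha when the zero operand is self.

```cpp
// Sketch of the mul_zerotensor / add_zerotensor fast paths (assumed build).
#include <ATen/ATen.h>

void zerotensor_binary_sketch() {
  at::Tensor z = at::_efficientzerotensor({3, 1}, at::TensorOptions().device(at::kCPU));  // float
  at::Tensor x = at::ones({1, 4}, at::kDouble);

  // mul: the Meta redispatch only computes the broadcast shape {3, 4} and
  // the promoted dtype (double); the result is again a ZeroTensor, so no
  // data is touched at any point.
  at::Tensor m = x * z;
  // m._is_zerotensor() == true, m.sizes() == {3, 4}, m.dtype() == double

  // add computes self + alpha * other; alpha scales `other`, so it only
  // matters when `other` is the non-zero operand.
  at::Tensor a = at::add(x, z);                  // x broadcast to {3, 4}
  at::Tensor b = at::add(z, x, /*alpha=*/2.0);   // 2 * x, broadcast to {3, 4}
}
```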
6 changes: 6 additions & 0 deletions aten/src/ATen/native/Copy.cpp
@@ -244,6 +244,12 @@ Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
{
NoNamesGuard guard;
if (self._is_zerotensor()) {
TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()` if you want a mutable zero tensor.");
}
if (src._is_zerotensor()) {
return self.zero_();
}
copy_impl(self, src, non_blocking);
}
namedinference::propagate_names_if_nonempty(self, maybe_outnames);
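
A brief sketch of the copy_ semantics added above, under the same build assumption: copying from a ZeroTensor is lowered to zero_() on the destination, while copying into a ZeroTensor raises an error.

```cpp
// Sketch of copy_ with ZeroTensor operands (assumed build).
#include <ATen/ATen.h>

void zerotensor_copy_sketch() {
  at::Tensor z = at::_efficientzerotensor({2, 2}, at::TensorOptions().device(at::kCPU));
  at::Tensor dst = at::randn({2, 2});

  dst.copy_(z);     // src is a ZeroTensor -> equivalent to dst.zero_()

  // z.copy_(dst);  // would throw: ZeroTensors are immutable; call
                    // z.clone() first to get a writable zeros tensor.
}
```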
3 changes: 1 addition & 2 deletions aten/src/ATen/native/Resize.cpp
@@ -27,8 +27,7 @@ bool resize_output_check(const Tensor& output, IntArrayRef shape) {

static auto kFunctorchWrappedTensors = DispatchKeySet({
DispatchKey::FuncTorchGradWrapper,
DispatchKey::FuncTorchBatched,
DispatchKey::FuncTorchPython});
DispatchKey::FuncTorchBatched});

static bool is_functorch_wrapped_tensor(const Tensor& tensor) {
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
23 changes: 21 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -349,9 +349,10 @@ Tensor empty_like(
namedinference::propagate_names(result, self.names());
}

// never propagate Conjugate and Negative dispatch key
// never propagate the Conjugate, Negative, and ZeroTensor dispatch keys
result._set_conj(false);
result._set_neg(false);
result._set_zero(false);
return result;
}

@@ -1057,6 +1058,20 @@ Tensor zeros(IntArrayRef size,
return result.zero_();
}

Tensor _efficientzerotensor(IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
caffe2::TypeMeta dtype_ = scalarTypeToTypeMeta(dtype_or_default(dtype));
Tensor tensor = detail::make_tensor<TensorImpl>(c10::DispatchKeySet({at::DispatchKey::ZeroTensor}), dtype_, device);
// Default TensorImpl has size [0]
if (size.size() != 1 || size[0] != 0) {
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
}
return tensor;
}

Tensor& zeros_out(IntArrayRef size, Tensor& result) {
if (result.is_sparse()) {
result.sparse_resize_and_clear_(size, size.size(), 0.);
@@ -1427,7 +1442,11 @@ Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory
self = at::empty_like(src, src.options(), memory_format);
}

self.copy_(src);
if (src._is_zerotensor()) {
self.zero_();
} else {
self.copy_(src);
}
return self;
}

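
A sketch of the factory and clone behaviour above, under the same build assumption: _efficientzerotensor builds a storage-less tensor that carries only the ZeroTensor dispatch key, empty_like never propagates that key, and clone materializes the result with zero_() instead of copy_.

```cpp
// Sketch of _efficientzerotensor, empty_like, and clone (assumed build).
#include <ATen/ATen.h>

void zerotensor_factory_sketch() {
  // No 4 MiB float buffer is allocated here.
  at::Tensor z = at::_efficientzerotensor({1024, 1024}, at::TensorOptions().device(at::kCPU));

  at::Tensor e = at::empty_like(z);   // _set_zero(false): ordinary, uninitialized
  // e._is_zerotensor() == false

  at::Tensor c = z.clone();           // materialized via empty_like + zero_()
  // c._is_zerotensor() == false, c is all zeros and mutable
  c.add_(1.0);                        // fine: c is a regular tensor
}
```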
8 changes: 8 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -767,6 +767,9 @@ Tensor make_qtensor(const Tensor& self, IntArrayRef size, IntArrayRef stride, Qu
}

Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional<int64_t> storage_offset_) {
if (self._is_zerotensor()) {
return at::_efficientzerotensor(size, self.options());
}
auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto result = detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
@@ -1024,6 +1027,11 @@ Tensor alias_with_sizes_and_strides(
const Tensor& self,
const Vec& sizes,
const Vec& strides) {
// it's okay to return a new tensor here since we disallow in-place operations on ZeroTensors
if (self._is_zerotensor()) {
return at::_efficientzerotensor(sizes, self.options());
}

Tensor self_;
if (self.is_quantized()) {
self_ = detail::make_tensor<QTensorImpl>(
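
A sketch of the view behaviour above, under the same build assumption: because ZeroTensors are immutable, view-style ops such as view, reshape, and as_strided can return a fresh efficient zero tensor of the target size instead of a true alias.

```cpp
// Sketch of view-like ops on a ZeroTensor (assumed build).
#include <ATen/ATen.h>

void zerotensor_view_sketch() {
  at::Tensor z = at::_efficientzerotensor({4, 6}, at::TensorOptions().device(at::kCPU));

  at::Tensor v = z.view({2, 12});   // still a ZeroTensor, shape {2, 12}
  at::Tensor r = z.reshape({24});   // likewise; no aliasing guarantees are
                                    // needed since in-place writes are
                                    // disallowed on ZeroTensors anyway
}
```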
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TypeProperties.cpp
@@ -30,6 +30,10 @@ bool is_signed(const Tensor &self) {
return self.is_signed();
}

bool _is_zerotensor(const Tensor& self) {
return self._is_zerotensor();
}

bool is_conj(const Tensor& self) {
return self.is_conj();
}
17 changes: 14 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -428,6 +428,7 @@
SparseCPU, SparseCUDA: add_sparse
SparseCsrCPU, SparseCsrCUDA: add_sparse_csr
MkldnnCPU: mkldnn_add
ZeroTensor: add_zerotensor

- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -694,7 +695,7 @@
- func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
variants: function, method
dispatch:
CPU, CUDA, Meta: as_strided_tensorimpl
ZeroTensor, CPU, CUDA, Meta: as_strided_tensorimpl
QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
device_check: NoCheck
device_guard: False
@@ -2429,6 +2430,11 @@
device_guard: False
manual_cpp_binding: True

- func: _is_zerotensor(Tensor self) -> bool
variants: function, method
device_guard: False
manual_cpp_binding: True

- func: is_neg(Tensor self) -> bool
variants: function, method
device_guard: False
@@ -3162,6 +3168,7 @@
dispatch:
SparseCPU, SparseCUDA: mul_sparse
MkldnnCPU: mkldnn_mul
ZeroTensor: mul_zerotensor

- func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -3624,7 +3631,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: _reshape_alias
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor: _reshape_alias
# We don't need to support mkldnn since this is handled explicitly by the reshape operator.

- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
@@ -4766,6 +4773,10 @@
device_check: NoCheck
device_guard: False

- func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
CompositeExplicitAutograd: _efficientzerotensor

- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
@@ -5825,7 +5836,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
ZeroTensor, CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
MkldnnCPU: mkldnn_view

# Warning: If you want to change the name or overload name of this
4 changes: 4 additions & 0 deletions aten/src/ATen/templates/Functions.h
@@ -221,6 +221,10 @@ inline bool is_inference(const Tensor& tensor) {
return tensor.is_inference();
}

inline bool _is_zerotensor(const Tensor& tensor) {
return tensor._is_zerotensor();
}

inline bool is_conj(const Tensor& tensor) {
return tensor.is_conj();
}
7 changes: 4 additions & 3 deletions c10/core/DispatchKey.cpp
@@ -111,6 +111,9 @@ const char* toString(DispatchKey t) {
return "AutogradPrivateUse3";
case DispatchKey::AutogradOther:
return "AutogradOther";

case DispatchKey::ZeroTensor:
return "ZeroTensor";
case DispatchKey::BackendSelect:
return "BackendSelect";
case DispatchKey::Named:
@@ -149,8 +152,6 @@ const char* toString(DispatchKey t) {
// https://github.com/zou3519/functorch
// We plan on eventually upstreaming the prototype into core, at which
// point it will have a different design that should use fewer keys.
case DispatchKey::FuncTorchPython:
return "FuncTorchPython";
case DispatchKey::FuncTorchDynamicLayerBackMode:
return "FuncTorchDynamicLayerBackMode";
case DispatchKey::FuncTorchDynamicLayerFrontMode:
@@ -242,10 +243,10 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"PrivateUse3", c10::DispatchKey::PrivateUse3},
{"BackendSelect", c10::DispatchKey::BackendSelect},
{"Python", c10::DispatchKey::Python},
{"FuncTorchPython", c10::DispatchKey::FuncTorchPython},
{"Named", c10::DispatchKey::Named},
{"Conjugate", c10::DispatchKey::Conjugate},
{"Negative", c10::DispatchKey::Negative},
{"ZeroTensor", c10::DispatchKey::ZeroTensor},
{"FuncTorchDynamicLayerBackMode",
c10::DispatchKey::FuncTorchDynamicLayerBackMode},
{"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
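
A tiny sketch of the string round-trip for the new dispatch key (useful when debugging dispatch issues), under the same build assumption.

```cpp
// Round-trip of the new ZeroTensor dispatch key name (assumed build).
#include <c10/core/DispatchKey.h>
#include <iostream>

int main() {
  c10::DispatchKey k = c10::parseDispatchKey("ZeroTensor");
  std::cout << c10::toString(k) << "\n";   // prints "ZeroTensor"
  return 0;
}
```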
(Diffs for the remaining changed files were not loaded on this page.)