Skip to content

Commit

Permalink
add special handling for resize_() in functionalization pass
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#77714

Approved by: https://github.com/ezyang
  • Loading branch information
bdhirsh authored and pytorchmergebot committed May 26, 2022
1 parent e9c54ae commit 92229ad
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 3 deletions.
40 changes: 40 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,46 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
}
}

void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
// Note [resize_() in functionalization pass]
// resize_() is a special operator in functionalization because it can reallocate its underlying storage.
// This function is only ever called in the case that resize_() needs to reallocate its storage to a larger size.
//
// However, functionalization currently bans the following code:
// a = torch.ones(2)
// b = a.view(2)
// b.resize_(4) # b is a view tensor, that we are trying to increase the storage size of
//
// Why is this code difficult to handle?
// The functionalization pass currently keeps aliases in sync by making the following assumptions:
// - The “base” tensor always refers to “all of the data”
// - Whenever you have b = view_op(a), “b” should always refer to a subset of “a”s memory.
//
// The code above breaks that assumption b.resize_(4) actually needs to update "a"
// to tell it that it is now actually some slice of a pre-existing larger storage.
// We're also no longer re-generate "b" fully from "a" anymore, since "a" refers to a slice of "b"'s data.
//
// This is probably fixable in theory, but:
// - the fix would likey complicated the functionalization logic quite a bit.
// - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators
// - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor.
//
// Given all of the above, for now we're just banning the above usage.
TORCH_CHECK(storage().use_count() == 1, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
TORCH_CHECK(view_metas_.size() == 0, "Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass");
// If this tensor is not a view (and has no outstanding views taken out on it),
// Then it's safe to throw out the old storage and replace it with the new, larger one.
storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
value_ = other;
generation_ = 0;
// And update the metadata on the wrapper to reflect the new sizes and strides
set_sizes_and_strides(value_.sizes(), value_.strides());
refresh_numel();
// (Technically we should be guaranteed that the tensor was already contiguous,
// since it's guaranteed not to have been a view. Doesnt hurt to run though)
refresh_contiguous();
}


void FunctionalTensorWrapper::sync_() {
if (is_up_to_date()) {
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// replace_() swaps out the wrapped tensor, value_, with tmp.
void replace_(const Tensor& other);

// See Note[resize_() in functionalization pass]
void maybe_replace_storage(const Tensor& other);

~FunctionalTensorWrapper() override = default;

private:
Expand Down
86 changes: 86 additions & 0 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <torch/library.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/to_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
#include <ATen/ops/empty_strided_native.h>
#endif

namespace {
Expand Down Expand Up @@ -81,6 +88,84 @@ namespace {
}
}

// Vanilla implementation to compute contiguous strides given some sizes.
// Should probably refactor this into shared code (also used in TensorImpl.h)
std::vector<int64_t> compute_contiguous_strides(c10::IntArrayRef sizes) {
auto n = sizes.size();
std::vector<int64_t> strides(n);
if (n == 0) return strides;

strides[n - 1] = 1;
for (int64_t i = n - 2; i >= 0; --i) {
strides[i] = strides[i+1] * sizes[i];
}
return strides;
}

// resize_() is special because:
// - when we resize to a larger size, it acts as a mutation
// - when we resize to a smaller size, it acts as a view
// See Note [resize_ in Functionalization] for more dtails
const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, c10::optional<at::MemoryFormat> memory_format) {
// First unwrap the tensor arguments
at::Tensor self_;
if (at::functionalization::impl::isFunctionalTensor(self)) {
at::functionalization::impl::sync(self);
self_ = at::functionalization::impl::from_functional_tensor(self);
} else {
self_ = self;
}
// Case 1: arguments are not functional tensors, so we no-op and redispatch.
if (!at::functionalization::impl::isFunctionalTensor(self)) {
at::AutoDispatchSkipFunctionalize guard;
at::Tensor tmp_output = self_.resize_(size, memory_format);
return self;
}

// Case 2: actually functionalize resize_()
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = at::resize_functional(self_, size, memory_format);
}

auto itemsize = self.dtype().itemsize();
auto storage_offset = self.storage_offset();
auto new_size_bytes = at::detail::computeStorageNbytesContiguous(size, itemsize, storage_offset);
auto needs_resize_storage = new_size_bytes > self.storage().nbytes();

if (needs_resize_storage) {
// If resize_() actually increases the size of the storage, then we need to tell FunctionalTensorWrapper about it.
// See Note[resize_() in functionalization pass]
auto func_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
func_impl->maybe_replace_storage(tmp_output);
// See the note - we're guaranteed at this point that "self" is *not* a view (and has no outstanding views)
// So we don't need to treat the output of resize as view tensor.
return self;
}

// Otherwise, we know that we're resizing to a smaller size.
// resize_() is effectively a view operator.
// The output of resizing is equivalent to taking a slice of a larger tensor.
// We have to emulate this "slicing" with an as_strided call.
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
if (reapply_views) {
return base.as_strided(size, compute_contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, compute_contiguous_strides(size));
}
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return base.as_strided_scatter(mutated_view, size, compute_contiguous_strides(size));
}
);
at::functionalization::impl::mutate_view_meta(self, view_meta);
return self;
}


at::Tensor lift_functionalize(const at::Tensor & self) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
return at::functionalization::impl::to_functional_tensor(self);
Expand All @@ -91,5 +176,6 @@ TORCH_LIBRARY_IMPL(_, Functionalize, m) {
}

TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
m.impl("resize_", TORCH_FN(resize__functionalization));
m.impl("lift", TORCH_FN(lift_functionalize));
}
92 changes: 89 additions & 3 deletions test/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,22 @@ def assert_functionalization(self, func, inpt, *, reapply_views=False):
finally:
torch._disable_functionalization()

self.assertEqual(out_ref.size(), out_functional.size())
# We need to sync the input tensors first, in case there are any queued mutations left.
torch._sync(input_functional)
torch._sync(out_functional)
self.assertEqual(out_ref, torch._from_functional_tensor(out_functional))
self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur

# Handle tests with multi-tensor outputs
if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
out_refs, out_functionals = list(out_ref), list(out_functional)
else:
out_refs, out_functionals = [out_ref], [out_functional]

for out_ref_, out_functional_ in zip(out_refs, out_functionals):
self.assertEqual(out_ref_.size(), out_functional_.size())
torch._sync(out_functional_)
out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
self.assertEqual(out_ref_, out_functional_unwrapped)

def test_multiple_views_of_same_base(self):
def f(x):
y = x.view(-1)
Expand Down Expand Up @@ -501,6 +510,83 @@ def f(x):
$2 = torch._ops.aten.diagonal_copy.default($1)
$3 = torch._ops.aten.fill.Scalar($2, 0)""")

def test_resize_smaller(self):
def f(w):
# Resizing to a smaller size doesn't affect storage
x = w + 1
y = x.view(4, 4)
y.resize_(3, 3)
y2 = y.view(-1)
y2.add_(1)
z = y + 1
return z

self.assert_functionalization(f, torch.ones(8, 2))
logs = self.get_logs(f, torch.ones(8, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, 1)
$2 = torch._ops.aten.view_copy.default($1, [4, 4])
$3 = torch._ops.aten.resize.functional($2, [3, 3])
$4 = torch._ops.aten.as_strided_copy.default($2, [3, 3], [3, 1])
$5 = torch._ops.aten.view_copy.default($4, [-1])
$6 = torch._ops.aten.add.Tensor($5, 1)
$7 = torch._ops.aten.view_copy.default($1, [4, 4])
$8 = torch._ops.aten.as_strided_copy.default($7, [3, 3], [3, 1])
$9 = torch._ops.aten.view_copy.default($6, [3, 3])
$10 = torch._ops.aten.as_strided_scatter.default($7, $9, [3, 3], [3, 1])
$11 = torch._ops.aten.view_copy.default($10, [8, 2])
$12 = torch._ops.aten.view_copy.default($11, [4, 4])
$13 = torch._ops.aten.as_strided_copy.default($12, [3, 3], [3, 1])
$14 = torch._ops.aten.add.Tensor($13, 1)""")

def test_resize_larger_valid(self):
def f(x):
y = x + 1
# resizing a tensor to a larger size is only currently allowed
# if the tensor-to-resize is not a view / has no outstanding views.
# See Note [resize_() in functionalization pass]
y.resize_(5, 5)
y2 = y.view(25)
# Do a mutation to ensure that aliases of the output of resize_()
# propagate mutations correctly.
# I'm using fill_ specifically because I want to guarantee that
# none of the output has uninitialized memory at the end
# (since these tests compare the data output against a reference impl)
y2.fill_(1)
out = y + 1
return y, out

self.assert_functionalization(f, torch.ones(8, 2))
logs = self.get_logs(f, torch.ones(8, 2))
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.aten.add.Tensor($0, 1)
$2 = torch._ops.aten.resize.functional($1, [5, 5])
$3 = torch._ops.aten.view_copy.default($2, [25])
$4 = torch._ops.aten.fill.Scalar($3, 1)
$5 = torch._ops.aten.view_copy.default($4, [5, 5])
$6 = torch._ops.aten.add.Tensor($5, 1)""")

def test_resize_larger_invalid(self):
def f(x):
y = x + 1
z = y.view(4, 4)
# resizing a tensor to a larger size is only currently allowed
# if the tensor-to-resize is not a view / has no outstanding views.
# See Note [resize_() in functionalization pass]
# This should fail
z.resize_(5, 5)
z2 = z.view(25)
z2.fill_(1)
out = z + 1
return y, out

with self.assertRaisesRegex(
RuntimeError,
r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'):
self.assert_functionalization(f, torch.ones(8, 2))

def test_nested_functions_propagate_updates(self):
def g(x):
# Create a view of x
Expand Down
3 changes: 3 additions & 0 deletions torchgen/gen_functionalization_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,9 @@ def emit_registration_helper(f: NativeFunction) -> str:
if str(f.func.name) == "lift":
# See Note [Functionalization <> torch.Tensor constructor]
return []
if str(f.func.name) == "resize_":
# See Note [resize_ in Functionalization]
return []
assert not f.is_view_op
# functionalization needs to generate and register kernals for inplace ops.
# We *also* need to directly register CompositeImplicitAUtograd kernels
Expand Down

0 comments on commit 92229ad

Please sign in to comment.