From c79489c8e69f965f3e5af8f3f39df78e7d4732ba Mon Sep 17 00:00:00 2001
From: albanD
Date: Thu, 24 Nov 2022 03:39:55 +0000
Subject: [PATCH] Expose to python the backward AD view_func (#89586)

This will be useful for other systems (AOTAutograd) that want to replay
autograd views.

FYI @bdhirsh

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89586
Approved by: https://github.com/soulitzer
---
 test/test_autograd.py                   | 22 ++++++++++++++++++
 torch/csrc/autograd/autograd_meta.cpp   | 10 ++++----
 torch/csrc/autograd/python_variable.cpp | 31 +++++++++++++++++++++++++
 torch/csrc/autograd/variable.h          |  5 ++++
 torch/csrc/cuda/comm.cpp                | 16 ++++++-------
 torch/overrides.py                      |  1 +
 6 files changed, 73 insertions(+), 12 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index 777b790da6559..4b1e97cb3b2b5 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7236,6 +7236,28 @@ def get_out():
         err_msg = "RuntimeError: one of the variables needed for gradient computation"
         self.assertTrue(err_msg in e.output.decode("utf-8"))
 
+    def test_view_func_replay(self):
+        def _assert_match_metadata(a, b):
+            self.assertEqual(a.size(), b.size())
+            self.assertEqual(a.stride(), b.stride())
+            self.assertEqual(a.storage_offset(), b.storage_offset())
+
+        def _test_op(fn, inp, args):
+            out = fn(inp, *args)
+            self.assertTrue(out._is_view())
+            self.assertTrue(out._base is inp)
+
+            new_inp = inp.clone()
+            _assert_match_metadata(new_inp, inp)
+            new_out = out._view_func(new_inp)
+            _assert_match_metadata(new_out, out)
+
+        _test_op(torch.select, torch.rand(2, 2), (0, 0))
+        _test_op(torch.as_strided, torch.rand(2, 2), ((4,), (1,)))
+        _test_op(torch.view_as_complex, torch.rand(2, 2), ())
+        _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())
+
+
 def index_perm_variable(shape, max_indices):
     if not isinstance(shape, tuple):
         shape = (shape,)
diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp
index db00d67576d3b..d11cd68e1800a 100644
--- a/torch/csrc/autograd/autograd_meta.cpp
+++ b/torch/csrc/autograd/autograd_meta.cpp
@@ -82,7 +82,7 @@ using at::Tensor;
 // base if needed. Case 5 is handled in fw_grad by reading the forward grad from
 // the base if needed.
 
-namespace {
+namespace utils {
 
 // Enforcing that the metadata between the primal and tangent are same has two
 // goals:
@@ -139,7 +139,8 @@ bool has_same_meta(const Variable& base, const Variable& other) {
   }
   return true;
 }
-} // anonymous namespace
+
+} // namespace utils
 
 // This function is will ensure that the fw_grad_ is properly a view of the base
 // for inplace ops on Tensors that do not have forward grad originally.
@@ -219,7 +220,8 @@ void AutogradMeta::set_fw_grad(
       // Enforce same meta here to make sure that the view op below is
       // always valid
       Tensor new_base_fw_grad;
-      if (has_same_meta(new_grad, base) && has_same_meta(new_grad, self)) {
+      if (utils::has_same_meta(new_grad, base) &&
+          utils::has_same_meta(new_grad, self)) {
        // TODO extend this special case to when the underlying storage of
        // new_grad can be re-used.
        new_base_fw_grad = new_grad;
@@ -248,7 +250,7 @@ void AutogradMeta::set_fw_grad(
   }
 
   // Enforce the basic layout constraint
-  if (!has_same_meta(new_grad, self)) {
+  if (!utils::has_same_meta(new_grad, self)) {
     if (is_view_) {
       auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
       TORCH_INTERNAL_ASSERT(
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index e3ab10c7499ca..a08d6f7761fd2 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -684,6 +684,36 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
   Py_RETURN_NONE;
 }
 
+static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  const auto& self = THPVariable_Unpack(self_);
+  TORCH_CHECK(
+      THPVariable_Check(arg),
+      "_view_func expects a single argument that is a Tensor");
+  const auto& new_base = THPVariable_Unpack(arg);
+
+  // Ensure that self is indeed a backward differentiable view
+  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
+  TORCH_CHECK(
+      diff_view_meta && diff_view_meta->has_bw_view(),
+      "_view_func can only be called on "
+      "a Tensor that is a backward differentiable view.");
+  const auto& view_info = diff_view_meta->get_backward_view();
+  // Ensure that the newly provided base is similar to the original base
+  TORCH_CHECK(
+      torch::autograd::utils::has_same_meta(new_base, view_info.base_),
+      "The new base passed to _view_func must have the same metadata as the Tensor's base");
+
+  // Do the actual view replay
+  if (view_info.has_view_fn()) {
+    return THPVariable_Wrap(view_info.view_fn()(new_base));
+  } else {
+    return THPVariable_Wrap(new_base.as_strided(
+        self.sizes(), self.strides(), self.storage_offset()));
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 // Instantiates a subclass of self with the same data.
 static PyObject* THPVariable_as_subclass(
     PyObject* _self,
@@ -1645,6 +1675,7 @@ static PyMethodDef extra_methods[] = {
      METH_STATIC | METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
+    {"_view_func", THPVariable_view_func, METH_O, nullptr},
     {nullptr}};
 
 /* From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c
diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h
index 49905fe803f46..52ce34ec394d0 100644
--- a/torch/csrc/autograd/variable.h
+++ b/torch/csrc/autograd/variable.h
@@ -791,6 +791,11 @@ inline Variable make_variable(
   return Variable();
 }
 
+namespace utils {
+
+TORCH_API bool has_same_meta(const Variable& base, const Variable& other);
+
+} // namespace utils
 } // namespace autograd
 } // namespace torch
 
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp
index 117f6b571792b..e215ce0e3ed67 100644
--- a/torch/csrc/cuda/comm.cpp
+++ b/torch/csrc/cuda/comm.cpp
@@ -180,12 +180,12 @@ tensor_list2d broadcast_coalesced(
 
   unique_type_checker type_checker;
   at::cuda::CUDAGuard device_guard(devices[0]);
-  for (auto& chunk : utils::take_tensors(tensors, buffer_size)) {
+  for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
     auto type_id = chunk.type_id();
     type_checker.show(type_id);
     std::vector<at::Tensor> results;
     if (chunk.options().is_sparse()) {
-      auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
+      auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
       auto broadcast_indices = broadcast(flat_tuple.first, devices);
       auto broadcast_values = broadcast(flat_tuple.second, devices);
       results.reserve(devices.size());
@@ -194,20 +194,20 @@ tensor_list2d broadcast_coalesced(
         auto& device_outputs = outputs[i];
         auto& inds = broadcast_indices[i];
         auto& vals = broadcast_values[i];
-        for (const auto& var :
-             utils::unflatten_sparse_tensors(inds, vals, chunk.tensors)) {
+        for (const auto& var : torch::utils::unflatten_sparse_tensors(
+                 inds, vals, chunk.tensors)) {
           // See NOTE [ Version Counter in comm.*_coalesced ]
           device_outputs.push_back(make_variable(var.tensor_data(), false));
         }
       }
     } else {
-      auto results =
-          broadcast(utils::flatten_dense_tensors(chunk.tensors), devices);
+      auto results = broadcast(
+          torch::utils::flatten_dense_tensors(chunk.tensors), devices);
       for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
         device_guard.set_index(devices[i]);
         auto& device_outputs = outputs[i];
         for (auto& var :
-             utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
+             torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
           // See NOTE [ Version Counter in comm.*_coalesced ]
           device_outputs.push_back(make_variable(var.tensor_data(), false));
         }
@@ -218,7 +218,7 @@ tensor_list2d broadcast_coalesced(
   // If we only saw a single tensor type, then we can skip expensive reordering
   if (!type_checker.unique) {
     for (auto& o : outputs)
-      utils::reorder_tensors_like(o, tensors);
+      torch::utils::reorder_tensors_like(o, tensors);
   }
   return outputs;
 }
diff --git a/torch/overrides.py b/torch/overrides.py
index 21cfe2477bd6f..ae2b23e17d30b 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -276,6 +276,7 @@ def get_ignored_functions() -> Set[Callable]:
         Tensor._typed_storage,
         Tensor._reduce_ex_internal,
         Tensor._fix_weakref,
+        Tensor._view_func,
         Tensor._make_wrapper_subclass,
         Tensor._python_dispatch.__get__,
         Tensor._has_symbolic_sizes_strides.__get__,
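
A minimal usage sketch of the newly exposed Tensor._view_func binding
(assuming a PyTorch build that includes this patch); it mirrors the
test_view_func_replay test above. Per the implementation, the replay uses
the saved view_fn when one exists and otherwise falls back to as_strided
with the view's sizes/strides/storage_offset, and the new base must have
the same metadata as the original base.

    import torch

    base = torch.rand(2, 2)
    view = base.select(0, 0)  # a backward differentiable view of `base`

    # The replacement base must match the original base's metadata,
    # otherwise _view_func raises an error.
    new_base = base.clone()

    # Replay the same view operation on the new base.
    replayed = view._view_func(new_base)

    assert replayed.size() == view.size()
    assert replayed.stride() == view.stride()
    assert replayed.storage_offset() == view.storage_offset()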