Expose to python the backward AD view_func (pytorch#89586)
This will be useful for other systems (AOTAutograd) that want to replay autograd views.

FYI @bdhirsh
Pull Request resolved: pytorch#89586
Approved by: https://github.com/soulitzer
albanD authored and pytorchmergebot committed Nov 24, 2022
1 parent 4cb6bbb commit c79489c
Showing 6 changed files with 73 additions and 12 deletions.
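
The new `Tensor._view_func(new_base)` method takes a replacement base tensor and replays the view that produced `self` on top of it. A minimal usage sketch, mirroring the new test below (`torch.select` is just one arbitrary view op; any backward-differentiable view should behave the same way):

import torch

base = torch.rand(2, 2)
view = base.select(0, 0)  # `view` is a backward-differentiable view of `base`

# Replay the same view operation on a fresh base with matching metadata.
new_base = base.clone()
replayed = view._view_func(new_base)

# The replayed view has the same size/stride/storage_offset as the original.
assert replayed.size() == view.size()
assert replayed.stride() == view.stride()
assert replayed.storage_offset() == view.storage_offset()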
22 changes: 22 additions & 0 deletions test/test_autograd.py
@@ -7236,6 +7236,28 @@ def get_out():
err_msg = "RuntimeError: one of the variables needed for gradient computation"
self.assertTrue(err_msg in e.output.decode("utf-8"))

    def test_view_func_replay(self):
        def _assert_match_metadata(a, b):
            self.assertEqual(a.size(), b.size())
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.storage_offset(), b.storage_offset())

        def _test_op(fn, inp, args):
            out = fn(inp, *args)
            self.assertTrue(out._is_view)
            self.assertTrue(out._base is inp)

            new_inp = inp.clone()
            _assert_match_metadata(new_inp, inp)
            new_out = out._view_func(new_inp)
            _assert_match_metadata(new_out, out)

        _test_op(torch.select, torch.rand(2, 2), (0, 0))
        _test_op(torch.as_strided, torch.rand(2, 2), ((4,), (1,)))
        _test_op(torch.view_as_complex, torch.rand(2, 2), ())
        _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ())


def index_perm_variable(shape, max_indices):
if not isinstance(shape, tuple):
shape = (shape,)
10 changes: 6 additions & 4 deletions torch/csrc/autograd/autograd_meta.cpp
@@ -82,7 +82,7 @@ using at::Tensor;
// base if needed. Case 5 is handled in fw_grad by reading the forward grad from
// the base if needed.

namespace {
namespace utils {

// Enforcing that the metadata between the primal and tangent are same has two
// goals:
@@ -139,7 +139,8 @@ bool has_same_meta(const Variable& base, const Variable& other) {
}
return true;
}
} // anonymous namespace

} // namespace utils

// This function is will ensure that the fw_grad_ is properly a view of the base
// for inplace ops on Tensors that do not have forward grad originally.
@@ -219,7 +220,8 @@ void AutogradMeta::set_fw_grad(
// Enforce same meta here to make sure that the view op below is
// always valid
Tensor new_base_fw_grad;
if (has_same_meta(new_grad, base) && has_same_meta(new_grad, self)) {
if (utils::has_same_meta(new_grad, base) &&
utils::has_same_meta(new_grad, self)) {
// TODO extend this special case to when the underlying storage of
// new_grad can be re-used.
new_base_fw_grad = new_grad;
@@ -248,7 +250,7 @@ void AutogradMeta::set_fw_grad(
}

// Enforce the basic layout constraint
if (!has_same_meta(new_grad, self)) {
if (!utils::has_same_meta(new_grad, self)) {
if (is_view_) {
auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
TORCH_INTERNAL_ASSERT(
31 changes: 31 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
@@ -684,6 +684,36 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
Py_RETURN_NONE;
}

static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(self_);
  TORCH_CHECK(
      THPVariable_Check(arg),
      "_view_func expect a single argument that is a Tensor");
  const auto& new_base = THPVariable_Unpack(arg);

  // Ensure that self is indeed a backward differentiable view
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  TORCH_CHECK(
      diff_view_meta && diff_view_meta->has_bw_view(),
      "_view_func can only be called on "
      "a Tensor that is a backward differentiable view.");
  const auto& view_info = diff_view_meta->get_backward_view();
  // Ensure that the newly provided base is similar to the original base
  TORCH_CHECK(
      torch::autograd::utils::has_same_meta(new_base, view_info.base_),
      "The new base passed to _view_func must have the same metadata as the Tensors's base");

  // Do the actual view replay
  if (view_info.has_view_fn()) {
    return THPVariable_Wrap(view_info.view_fn()(new_base));
  } else {
    return THPVariable_Wrap(new_base.as_strided(
        self.sizes(), self.strides(), self.storage_offset()));
  }
  END_HANDLE_TH_ERRORS
}

// Instantiates a subclass of self with the same data.
static PyObject* THPVariable_as_subclass(
PyObject* _self,
@@ -1645,6 +1675,7 @@ static PyMethodDef extra_methods[] = {
METH_STATIC | METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
{"_view_func", THPVariable_view_func, METH_O, nullptr},
{nullptr}};

/* From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c
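
Note the two replay paths in THPVariable_view_func above: if the view recorded an explicit view_fn (used for views that as_strided cannot reproduce, e.g. view_as_complex), that function is invoked on the new base; otherwise the replay falls back to as_strided with the original view's sizes, strides and storage offset. A small sketch of that equivalence for a plain strided view (an assumed illustration, not part of the commit):

import torch

base = torch.rand(2, 2)
view = base.select(0, 0)
new_base = base.clone()

replayed = view._view_func(new_base)
# Manual replay with the view's metadata; for simple strided views this
# matches what _view_func returns.
manual = new_base.as_strided(view.size(), view.stride(), view.storage_offset())
assert torch.equal(replayed, manual)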
5 changes: 5 additions & 0 deletions torch/csrc/autograd/variable.h
@@ -791,6 +791,11 @@ inline Variable make_variable(
return Variable();
}

namespace utils {

TORCH_API bool has_same_meta(const Variable& base, const Variable& other);

} // namespace utils
} // namespace autograd
} // namespace torch

16 changes: 8 additions & 8 deletions torch/csrc/cuda/comm.cpp
@@ -180,12 +180,12 @@ tensor_list2d broadcast_coalesced(

unique_type_checker type_checker;
at::cuda::CUDAGuard device_guard(devices[0]);
for (auto& chunk : utils::take_tensors(tensors, buffer_size)) {
for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
auto type_id = chunk.type_id();
type_checker.show(type_id);
std::vector<at::Tensor> results;
if (chunk.options().is_sparse()) {
auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
auto broadcast_indices = broadcast(flat_tuple.first, devices);
auto broadcast_values = broadcast(flat_tuple.second, devices);
results.reserve(devices.size());
@@ -194,20 +194,20 @@ tensor_list2d broadcast_coalesced(
auto& device_outputs = outputs[i];
auto& inds = broadcast_indices[i];
auto& vals = broadcast_values[i];
for (const auto& var :
utils::unflatten_sparse_tensors(inds, vals, chunk.tensors)) {
for (const auto& var : torch::utils::unflatten_sparse_tensors(
inds, vals, chunk.tensors)) {
// See NOTE [ Version Counter in comm.*_coalesced ]
device_outputs.push_back(make_variable(var.tensor_data(), false));
}
}
} else {
auto results =
broadcast(utils::flatten_dense_tensors(chunk.tensors), devices);
auto results = broadcast(
torch::utils::flatten_dense_tensors(chunk.tensors), devices);
for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
device_guard.set_index(devices[i]);
auto& device_outputs = outputs[i];
for (auto& var :
utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
// See NOTE [ Version Counter in comm.*_coalesced ]
device_outputs.push_back(make_variable(var.tensor_data(), false));
}
@@ -218,7 +218,7 @@ tensor_list2d broadcast_coalesced(
// If we only saw a single tensor type, then we can skip expensive reordering
if (!type_checker.unique) {
for (auto& o : outputs)
utils::reorder_tensors_like(o, tensors);
torch::utils::reorder_tensors_like(o, tensors);
}
return outputs;
}
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -276,6 +276,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._typed_storage,
Tensor._reduce_ex_internal,
Tensor._fix_weakref,
Tensor._view_func,
Tensor._make_wrapper_subclass,
Tensor._python_dispatch.__get__,
Tensor._has_symbolic_sizes_strides.__get__,
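
Adding Tensor._view_func to get_ignored_functions() keeps it out of the __torch_function__ override machinery, like the other private Tensor helpers in this list. A quick sanity check of that registration (a sketch, not part of the commit):

import torch
from torch.overrides import get_ignored_functions

# _view_func is now listed among the functions __torch_function__ ignores.
assert torch.Tensor._view_func in get_ignored_functions()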
