add nested tensor matmul support (pytorch#81957)
There was a discussion on whether to let nested tensor `reshape` support collapsing and splitting dimension 0. The conclusion was to keep `reshape` simple, so instead we need a tweaked `matmul` that only supports the 3+ dimensional, non-broadcasting case, i.e. a generalized `bmm`.
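As a sketch of the semantics this adds (not part of the commit; the function name and the use of a plain vector of entry tensors are illustrative only), matmul on a pair of nested tensors behaves like an entry-wise matmul over the tensors they hold, with no broadcasting across batch dimensions:

#include <ATen/ATen.h>
#include <vector>

// Illustrative reference: each pair of entries must have the same number
// (>= 3) of dims; the leading dims are batch dims and must match, and the
// last two dims are matrix dims. The nested result packs the per-entry outputs.
std::vector<at::Tensor> generalized_bmm_reference(
    const std::vector<at::Tensor>& self_entries,
    const std::vector<at::Tensor>& mat2_entries) {
  TORCH_CHECK(self_entries.size() == mat2_entries.size(),
              "both inputs must hold the same number of entries");
  std::vector<at::Tensor> out;
  out.reserve(self_entries.size());
  for (size_t i = 0; i < self_entries.size(); i++) {
    // e.g. entry i of shape (B_i, M_i, K_i) times (B_i, K_i, N_i) gives
    // (B_i, M_i, N_i); B_i, M_i, K_i, N_i may differ across entries i.
    out.push_back(at::matmul(self_entries[i], mat2_entries[i]));
  }
  return out;
}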

Pull Request resolved: pytorch#81957
Approved by: https://github.com/jbschlosser
YifanShenSZ authored and pytorchmergebot committed Jul 30, 2022
1 parent 32cf6c6 commit 4bb7e14
Showing 9 changed files with 601 additions and 110 deletions.
18 changes: 16 additions & 2 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1880,8 +1880,22 @@ Tensor _matmul_impl(

Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2);
at::Tensor unused;
auto result = at::native::_matmul_impl(unused, tensor1, tensor2);
at::Tensor result, unused;
// Note [is_nested check]
// We have 2 choices to support nested tensor matmul:
// 1. intercept here by is_nested check
// 2. add nested tensor dispatch key
// Although 1. is hacky, we choose it because we are unsure about 2.:
// * We tried 2. for reshape and it caused a weird autograd bug
// (see comment in reshape in TensorShape.cpp)
// * yet 2. works for linear?
// TODO: use 2. after we make sure it is fine
if (tensor1.is_nested() || tensor2.is_nested()) {
result = at::_NestedTensor_GeneralizedBMM(tensor1, tensor2);
}
else {
result = at::native::_matmul_impl(unused, tensor1, tensor2);
}
namedinference::propagate_names_if_nonempty(result, maybe_outnames);
return result;
}
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1138,6 +1138,10 @@
SparseCUDA: bmm_out_sparse_cuda
SparseCsrCUDA: bmm_out_sparse_csr_cuda

- func: _NestedTensor_GeneralizedBMM(Tensor self, Tensor mat2) -> Tensor
dispatch:
NestedTensorCPU, NestedTensorCUDA: _NestedTensor_GeneralizedBMM

- func: broadcast_tensors(Tensor[] tensors) -> Tensor[]
device_check: NoCheck
device_guard: False
@@ -4060,6 +4064,10 @@
dispatch:
NestedTensorCPU, NestedTensorCUDA: _reshape_nested

- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor
dispatch:
NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward

# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
# They are not user-facing, hence the leading underscore. Please don't use it
# anywhere else.
68 changes: 68 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorBackward.cpp
@@ -0,0 +1,68 @@
#include <ATen/native/nested/NestedTensorMath.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/layer_norm.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <ATen/native/nested/NestedTensorMath.h>

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
const Tensor& input,
const Tensor& grad_output,
const Tensor& weight,
std::array<bool, 3> output_mask) {
if (!grad_output.defined()) {
return std::tuple<Tensor, Tensor, Tensor>{Tensor(), Tensor(), Tensor()};
}
Tensor grad_input, grad_weight, grad_bias;
auto* nt_grad_output = get_nested_tensor_impl(grad_output);
auto* nt_input = get_nested_tensor_impl(input);
TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr);
TORCH_INTERNAL_ASSERT(nt_input != nullptr);
TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_grad_output));
auto grad_output_buffer = nt_grad_output->get_buffer();
auto input_buffer = nt_input->get_buffer();

auto reshaped_grad = grad_output_buffer.reshape({-1, weight.size(0)});

if (output_mask[0]) {
auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1});
auto grad_input_nt_size = nt_input->get_nested_size_tensor().clone();
grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size);
}
if (output_mask[1]) {
grad_weight =
at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)}));
}
if (output_mask[2]) {
grad_bias = reshaped_grad.sum(0);
}
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) {
auto self_ptr = get_nested_tensor_impl(self);
// TODO: this is to reproduce self_ptr->opt_sizes_
// if an accessor is provided in the future, can replace this
std::vector<int64_t> sizes;
for (int64_t i = 0; i < self_ptr->dim(); i++) {
c10::optional<int64_t> opt_size = self_ptr->opt_size(i);
if (opt_size.has_value()) {
sizes.push_back(*opt_size);
}
else {
sizes.push_back(-1);
}
}
return grad.reshape(sizes);
}

} // namespace native
} // namespace at
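For clarity, here is a dense-tensor sketch (not the commit's code; the function name is hypothetical and plain 2D tensors stand in for the flattened nested buffers) of the gradient formulas that nested_linear_backward above applies to the packed buffers. _reshape_nested_backward simply reshapes the incoming gradient back to the original nested sizes, with -1 standing in for the ragged dimensions.

#include <ATen/ATen.h>
#include <tuple>

// Sketch of the linear backward math on flattened buffers:
//   input       : (N, in_features)
//   weight      : (out_features, in_features)
//   grad_output : (N, out_features)
std::tuple<at::Tensor, at::Tensor, at::Tensor> linear_backward_dense_sketch(
    const at::Tensor& input,
    const at::Tensor& grad_output,
    const at::Tensor& weight) {
  at::Tensor grad_input = at::mm(grad_output, weight);      // (N, in_features)
  at::Tensor grad_weight = at::mm(grad_output.t(), input);  // (out, in)
  at::Tensor grad_bias = grad_output.sum(0);                // (out,)
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}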
187 changes: 133 additions & 54 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -96,64 +96,10 @@ Tensor pad_tensor_to_shape(
}
} // namespace



inline const at::Tensor& get_buffer(const at::Tensor& tensor) {
return get_nested_tensor_impl(tensor)->get_buffer();
}

// The sizes of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_sizes(const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> sizes(ntensors);
if (ntensors == 0) {
return sizes;
}
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
int64_t orig_dim = sizemat.size(1);
// nesting scalars has empty sizes
if (orig_dim == 0) {
return sizes;
}
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
sizemat_ptr += orig_dim;
}
return sizes;
}

inline std::vector<IntArrayRef> NestedTensor_get_sizes(const at::Tensor& self) {
const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
return NestedTensor_get_sizes(self_ptr);
}

// The strides of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_strides(const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> strides(ntensors);
if (ntensors == 0) {
return strides;
}
const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
int64_t orig_dim = stridemat.size(1);
// nesting scalars has empty strides
if (orig_dim == 0) {
return strides;
}
const int64_t* stridemat_ptr = stridemat.data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim);
stridemat_ptr += orig_dim;
}
return strides;
}

inline std::vector<IntArrayRef> NestedTensor_get_strides(const at::Tensor& self) {
const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
return NestedTensor_get_strides(self_ptr);
}

std::vector<at::Tensor> NestedTensor_unbind(
const at::Tensor& self,
int64_t dim) {
@@ -814,6 +760,13 @@ Tensor softmax_nested(
}

Tensor bmm_nested(const Tensor& self, const Tensor& mat2) {
if (self.is_nested() && !mat2.is_nested()) {
AT_ERROR("Expected both to be nested, but got a nested self and non-nested other");
}
else if (!self.is_nested() && mat2.is_nested()) {
AT_ERROR("Expected both to be nested, but got a non-nested self and nested other");
}
// dispatcher should have guaranteed that at least one is nested
auto self_ptr = get_nested_tensor_impl(self);
auto mat2_ptr = get_nested_tensor_impl(mat2);
TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
@@ -866,6 +819,132 @@ Tensor bmm_nested(const Tensor& self, const Tensor& mat2) {
return output;
}

// utilities support _NestedTensor_GeneralizedBMM
namespace {
inline std::tuple<std::vector<int64_t>, Tensor>
_NestedTensor_GeneralizedBMM_BatchSizes_OutputMemory(
const std::vector<IntArrayRef>& self_sizes,
const std::vector<IntArrayRef>& mat2_sizes,
const c10::TensorOptions& buffer_op,
const c10::TensorOptions& sizemat_op) {
int64_t ntensors = self_sizes.size(),
ndims = self_sizes[0].size();
std::vector<int64_t> batch_sizes(ntensors, 1);
Tensor sizemat = at::empty({ntensors, ndims}, sizemat_op);
int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
int64_t numel = 0;
for (int64_t i = 0; i < ntensors; i++) {
const IntArrayRef& self_size = self_sizes[i],
& mat2_size = mat2_sizes[i];
int64_t& batch_size = batch_sizes[i];
// batch dimensions
for (int64_t j = 0; j < ndims - 2; j++) {
const int64_t& self_sizej = self_size[j],
& mat2_sizej = mat2_size[j];
TORCH_CHECK(
self_sizej == mat2_sizej,
"matmul: For nested tensors, no broadcasting is currently performed: ",
i, "-th nested matrices in batch at dimension ", j + 1,
" have mismatching sizes ", self_sizej, " and ", mat2_sizej);
sizemat_ptr[j] = self_sizej;
batch_size *= sizemat_ptr[j];
}
// matrix multiplication dimensions
const int64_t& self_size0 = self_size[ndims - 2], & self_size1 = self_size[ndims - 1],
& mat2_size0 = mat2_size[ndims - 2], & mat2_size1 = mat2_size[ndims - 1];
TORCH_CHECK(
self_size1 == mat2_size0,
"matmul: ",
i, "-th nested matrices in batch cannot be multiplied (",
self_size0, "x", self_size1, " and ",
mat2_size0, "x", mat2_size1, ")");
sizemat_ptr[ndims - 2] = self_size0;
sizemat_ptr[ndims - 1] = mat2_size1;
sizemat_ptr += ndims;
numel += batch_size * self_size0 * mat2_size1;
}
Tensor buffer = at::empty(numel, buffer_op);
Tensor output = wrap_buffer(buffer, sizemat);
return std::make_tuple(batch_sizes, output);
}
}

// This is a generalized batched matmul dedicated to nested tensors,
// where `self` and `mat2` have the same number (>= 3) of dimensions.
// The last 2 dimensions are treated as matrix dimensions,
// so they must be matrix-multiplicable.
// The leading dimensions are treated as batch dimensions,
// and since nested tensors do not support broadcasting for now,
// `self` and `mat2` must have the same size in each batch dimension.
Tensor _NestedTensor_GeneralizedBMM(const Tensor& self, const Tensor& mat2) {
if (self.is_nested() && !mat2.is_nested()) {
AT_ERROR("Expected both to be nested, but got a nested self and non-nested other");
}
else if (!self.is_nested() && mat2.is_nested()) {
AT_ERROR("Expected both to be nested, but got a non-nested self and nested other");
}
// dispatcher should have guaranteed that at least one is nested
auto self_ptr = get_nested_tensor_impl(self),
mat2_ptr = get_nested_tensor_impl(mat2);
int64_t self_dim = self_ptr->dim(),
mat2_dim = mat2_ptr->dim();
TORCH_CHECK(
self_dim >= 3,
"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: ",
self_dim);
TORCH_CHECK(
mat2_dim >= 3,
"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ",
mat2_dim);
TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have same rank");
int64_t ntensors = self_ptr->size(0),
ntensors2 = mat2_ptr->size(0);
TORCH_CHECK(ntensors == ntensors2,
"matmul: Expected size for the 1st dimension of 2nd input tensor to be: ", ntensors,
" but got: ", ntensors2, ".");
const Tensor& self_buffer = self_ptr->get_buffer(),
& mat2_buffer = mat2_ptr->get_buffer();
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
self_strides = NestedTensor_get_strides(self_ptr),
mat2_strides = NestedTensor_get_strides(mat2_ptr);
const std::vector<int64_t>& self_offsets = self_ptr->get_offsets(),
& mat2_offsets = mat2_ptr->get_offsets();
// create a contiguous output
std::vector<int64_t> batch_sizes;
Tensor output;
std::tie(batch_sizes, output) = _NestedTensor_GeneralizedBMM_BatchSizes_OutputMemory(
self_sizes, mat2_sizes, self_buffer.options(), self_ptr->get_nested_size_tensor().options());
// call tensor matmul
// TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient
// until we have specialized nested tensor bmm kernel
// useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_`
// `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl`
std::vector<Tensor> output_unbind = output.unbind();
for (int64_t i = 0; i < ntensors; i++) {
const IntArrayRef& self_size = self_sizes[i],
& mat2_size = mat2_sizes[i];
const int64_t& batch_size = batch_sizes[i];
if (batch_size == 1) {
at::mm_out(
output_unbind[i],
self_buffer.as_strided(self_size, self_strides[i], self_offsets[i]),
mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i])
);
}
else {
at::bmm_out(
output_unbind[i],
self_buffer.as_strided(self_size, self_strides[i], self_offsets[i])
.reshape({batch_size, self_size[self_dim - 1 - 2], self_size[self_dim - 1 - 1]}),
mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i])
.reshape({batch_size, mat2_size[self_dim - 1 - 2], mat2_size[self_dim - 1 - 1]})
);
}
}
return output;
}

Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
auto self_ptr = get_nested_tensor_impl(self);
// check input dimensions
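To make the size bookkeeping above concrete, here is a standalone sketch (hypothetical names, not the commit's code) of what _NestedTensor_GeneralizedBMM_BatchSizes_OutputMemory derives for each entry pair: the output size row, the collapsed batch size, and the number of elements contributed to the flat output buffer.

#include <cstdint>
#include <stdexcept>
#include <vector>

// Per-entry plan for one pair of nested entries with sizes
// self_size / mat2_size (both of length ndims >= 3).
struct EntryPlan {
  std::vector<int64_t> out_size; // row written into the output's nested size matrix
  int64_t batch_size;            // product of the leading (batch) dims
  int64_t numel;                 // elements this entry adds to the flat buffer
};

EntryPlan plan_entry(const std::vector<int64_t>& self_size,
                     const std::vector<int64_t>& mat2_size) {
  const int64_t ndims = static_cast<int64_t>(self_size.size());
  EntryPlan plan{std::vector<int64_t>(ndims), 1, 0};
  for (int64_t j = 0; j < ndims - 2; j++) {
    if (self_size[j] != mat2_size[j]) {
      throw std::runtime_error("no broadcasting across batch dimensions");
    }
    plan.out_size[j] = self_size[j];
    plan.batch_size *= self_size[j];
  }
  if (self_size[ndims - 1] != mat2_size[ndims - 2]) {
    throw std::runtime_error("matrix dimensions are not multiplicable");
  }
  plan.out_size[ndims - 2] = self_size[ndims - 2];  // M
  plan.out_size[ndims - 1] = mat2_size[ndims - 1];  // N
  plan.numel = plan.batch_size * self_size[ndims - 2] * mat2_size[ndims - 1];
  return plan;
}

// Example: self entry (2, 5, 3) times mat2 entry (2, 3, 7) gives out_size
// (2, 5, 7), batch_size 2, numel 70; summing numel over all entries sizes
// the contiguous output buffer that the mm_out/bmm_out loop then fills.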
52 changes: 52 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.h
@@ -28,6 +28,58 @@ inline at::Tensor wrap_buffer(
std::move(nested_stride_tensor), offsets);
}

// The sizes of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_sizes(const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> sizes(ntensors);
if (ntensors == 0) {
return sizes;
}
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
int64_t orig_dim = sizemat.size(1);
// nesting scalars has empty sizes
if (orig_dim == 0) {
return sizes;
}
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
sizemat_ptr += orig_dim;
}
return sizes;
}

inline std::vector<IntArrayRef> NestedTensor_get_sizes(const at::Tensor& self) {
const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
return NestedTensor_get_sizes(self_ptr);
}

// The strides of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_strides(const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> strides(ntensors);
if (ntensors == 0) {
return strides;
}
const Tensor& stridemat = self_ptr->get_nested_stride_tensor();
int64_t orig_dim = stridemat.size(1);
// nesting scalars has empty strides
if (orig_dim == 0) {
return strides;
}
const int64_t* stridemat_ptr = stridemat.data_ptr<int64_t>();
for (int64_t i = 0; i < ntensors; i++) {
strides[i] = IntArrayRef(stridemat_ptr, stridemat_ptr + orig_dim);
stridemat_ptr += orig_dim;
}
return strides;
}

inline std::vector<IntArrayRef> NestedTensor_get_strides(const at::Tensor& self) {
const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
return NestedTensor_get_strides(self_ptr);
}

TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
const NestedTensorImpl& nt);

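As a worked example of the metadata these moved helpers unpack (an illustrative sketch, not part of the commit): a contiguous nested tensor holding entries of shape (2, 3) and (4, 3) stores one flat buffer of 18 elements, a size matrix {{2, 3}, {4, 3}}, a stride matrix {{3, 1}, {3, 1}}, and offsets {0, 6}. NestedTensor_get_sizes and NestedTensor_get_strides simply slice those int64 matrices row by row into per-entry views:

#include <c10/util/ArrayRef.h>
#include <cstdint>
#include <vector>

// Sketch: view each row of a row-major (ntensors x orig_dim) int64 matrix
// as an ArrayRef, mirroring NestedTensor_get_sizes / NestedTensor_get_strides.
std::vector<c10::ArrayRef<int64_t>> rows_as_arrayrefs(
    const int64_t* mat, int64_t ntensors, int64_t orig_dim) {
  std::vector<c10::ArrayRef<int64_t>> rows(ntensors);
  for (int64_t i = 0; i < ntensors; i++) {
    rows[i] = c10::ArrayRef<int64_t>(mat + i * orig_dim, orig_dim);
  }
  return rows;
}

// Usage on the (2, 3) / (4, 3) example:
//   int64_t sizemat[]   = {2, 3, 4, 3};
//   int64_t stridemat[] = {3, 1, 3, 1};
//   auto sizes   = rows_as_arrayrefs(sizemat, 2, 2);   // {2, 3}, {4, 3}
//   auto strides = rows_as_arrayrefs(stridemat, 2, 2); // {3, 1}, {3, 1}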
