Support softmax(nested tensor) (pytorch#80179)
YifanShenSZ authored and pytorchmergebot committed Jun 24, 2022
1 parent 79ba65c commit 54a1cc5
Showing 3 changed files with 68 additions and 0 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4348,6 +4348,9 @@
# softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models.
- func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
  variants: function, method
  dispatch:
    CompositeImplicitAutograd: softmax
    NestedTensorCPU, NestedTensorCUDA: softmax

- func: softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
@@ -4361,6 +4364,7 @@
  structured_delegate: _softmax.out
  dispatch:
    MkldnnCPU: mkldnn_softmax
    NestedTensorCPU, NestedTensorCUDA: softmax_nested

- func: _softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
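Taken together, these dispatch entries keep the existing CompositeImplicitAutograd path for ordinary tensors and route the underlying `_softmax` call on nested tensors to the new `softmax_nested` kernel. A minimal usage sketch (assuming a PyTorch build that includes this commit, and the `torch.nested_tensor` constructor used in the tests below):

import torch
import torch.nn.functional as F

# two constituents with different leading dimensions, same trailing dimension
nt = torch.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])

# softmax over the last dimension of each constituent;
# dim 0 (the nested dimension) is rejected with a RuntimeError
out = F.softmax(nt, dim=2)
print([t.shape for t in out.unbind()])  # [torch.Size([2, 3]), torch.Size([4, 3])]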
32 changes: 32 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -681,5 +681,37 @@ Tensor& dropout_nested_(Tensor& input, double p, bool train) {
  return input;
}

Tensor softmax_nested(const Tensor& input, const int64_t dim, const bool half_to_float) {
  auto input_ptr = get_nested_tensor_impl(input);
  int64_t ntensors = input_ptr->size(0);
  if (ntensors == 0) {
    return input;
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, input_ptr->dim());
  TORCH_CHECK(
      positive_dim >= 1,
      "Cannot apply softmax across nested dimension 0");
  const Tensor& buffer = input_ptr->get_buffer(),
      & sizemat = input_ptr->get_nested_size_tensor();
  Tensor output_buffer = buffer.new_empty(buffer.sizes());
  // split buffer into original tensors
  std::vector<int64_t> offsets = NestedTensor_get_offsets(input_ptr);
  std::vector<IntArrayRef> shapes = NestedTensor_get_shapes(input_ptr);
  // call tensor softmax on each constituent
  // TODO: for CPU, consider `parallel_for` if benchmarks show it is necessary;
  // that would require merging `softmax_kernel` from
  // `aten/src/ATen/native/cpu/SoftMaxKernel.cpp`, because:
  // 1. it already uses `parallel_for`, and we cannot nest multi-threading
  // 2. we cannot dispatch inside a multi-threaded region (here, at::_softmax_out)
  for (int64_t i = 0; i < ntensors; i++) {
    Tensor out = output_buffer.slice(0, offsets[i], offsets[i + 1]).view(shapes[i]);
    at::_softmax_out(
        out,
        buffer.slice(0, offsets[i], offsets[i + 1]).view(shapes[i]),
        positive_dim - 1,
        half_to_float);
  }
  return wrap_buffer(output_buffer, sizemat.clone());
}

} // namespace native
} // namespace at
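For reference, the kernel above views each contiguous slice of the underlying buffer as one constituent tensor and runs `at::_softmax_out` on it, shifting `dim` down by one because dimension 0 is the nested dimension. A Python sketch of the same semantics (`softmax_nested_reference` is a hypothetical helper for illustration, not part of this commit):

import torch

def softmax_nested_reference(nt, dim):
    # dim 0 is the nested dimension; softmax across it is rejected,
    # mirroring the TORCH_CHECK in softmax_nested
    assert dim >= 1, "Cannot apply softmax across nested dimension 0"
    # apply softmax to each constituent, with dim shifted past the nested dimension
    return torch.nested_tensor([torch.softmax(t, dim - 1) for t in nt.unbind()])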
32 changes: 32 additions & 0 deletions test/test_nestedtensor.py
@@ -546,6 +546,38 @@ def test_dropout(self, device, dtype):
                    expect_tensor[j] /= 1.0 - p
        self.nt_equal(nt, expect)

    # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    @torch.inference_mode()
    def test_softmax(self, device, dtype):
        # normal nested tensor
        ntensors = 4
        nt = self.random_nt(device, dtype, ntensors, (4, 4))
        # error case: softmax across nested dimension
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt, 0))
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt, -3))
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
        # normal case: result should match softmax over the tensor padded with -inf
        softmaxer = torch.nn.Softmax(1)
        y0 = softmaxer(nt)
        y1 = torch.nn.functional.softmax(nt, 1)
        self.nt_equal(y0, y1)
        pt = nt.to_padded_tensor(float("-inf"))
        # if a slice consists entirely of padding, softmax returns 0.0 / 0.0 = nan;
        # semantically that slice should be all 0.0
        expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
        self.assertEqual(y0.to_padded_tensor(0.0), expect)
        # edge case: empty nested tensor
        nt0 = torch.nested_tensor([])
        y = torch.nn.functional.softmax(nt0, 1)
        self.nt_equal(nt0, y)
        # edge case: nesting scalars
        nt1 = torch.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))

class TestNestedTensorAutograd(TestCase):
    def nt_equal(self, nt1, nt2):
        self.assertEqual(nt1.dtype, nt2.dtype)
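The padding comparison in `test_softmax` relies on exp(-inf) = 0: padded entries contribute nothing to the normalizing sum, so softmax over the -inf-padded tensor matches softmax over each nested constituent, except that an all-padding slice produces 0/0 = nan and must be mapped back to 0.0. A small standalone illustration (not part of the commit):

import torch

t = torch.tensor([1.0, 2.0, float("-inf")])
print(torch.softmax(t, dim=0))  # the -inf slot gets probability 0.0

all_pad = torch.full((3,), float("-inf"))
print(torch.softmax(all_pad, dim=0).nan_to_num(0.0))  # all-padding slice -> 0.0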
