diff --git a/aten/src/ATen/native/cuda/Loss.cu b/aten/src/ATen/native/cuda/Loss.cu
index fcb3229198ab7..ed0d61c24a5fd 100644
--- a/aten/src/ATen/native/cuda/Loss.cu
+++ b/aten/src/ATen/native/cuda/Loss.cu
@@ -152,6 +152,9 @@ namespace {
 constexpr int NLL_LOSS_THREADS = 32;
 
+// TODO(crcrpar): Think about removing this dispatch, and introducing canUse32BitIndexMath
+// NOTE(crcrpar): The nll loss implementation in ATen/native/cuda/NLLLoss2d.cu doesn't use the following dispatch for `target` and hardcodes int64_t.
+// With this dispatch, `target` could be Byte while `ignore_index` is int64_t, which doesn't sound quite reasonable.
 #define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...)                        \
   AT_DISPATCH_SWITCH(TYPE, NAME,                                                 \
   AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Byte, index_t, __VA_ARGS__)    \
@@ -164,8 +167,8 @@ __global__ void nll_loss_forward_no_reduce_cuda_kernel(
     index_t* target,
     scalar_t* output,
     scalar_t* weights,
-    int n_classes,
-    int ignore_index) {
+    int64_t n_classes,
+    int64_t ignore_index) {
   CUDA_KERNEL_LOOP(index, batch_size) {
     int cur_target = target[index];
     if (cur_target == ignore_index) {
@@ -187,12 +190,12 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_1d(
     index_t* target,
     scalar_t* weights,
     bool size_average,
-    int n_classes,
+    int64_t n_classes,
     int64_t ignore_index) {
   CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);
 
-  int t = static_cast<int>(*target);
-  if (t != static_cast<int>(ignore_index)) {
+  int64_t t = static_cast<int64_t>(*target);
+  if (t != ignore_index) {
     CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
     const auto cur_weight = weights != nullptr ? weights[t] : scalar_t{1};
     *total_weight = cur_weight;
@@ -223,9 +226,9 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
     index_t* target,
     scalar_t* weights,
     bool size_average,
-    int nframe,
-    int ndim,
-    int n_classes,
+    int64_t nframe,
+    int64_t ndim,
+    int64_t n_classes,
     int64_t ignore_index) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   __shared__ accscalar_t sh_inputs[NLL_LOSS_THREADS],
@@ -234,8 +237,8 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
   sh_inputs[threadIdx.x] = static_cast<accscalar_t>(0);
   acc_weight[threadIdx.x] = static_cast<accscalar_t>(0);
   for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
-    int t = target[i];
-    if (t != static_cast<int>(ignore_index)) {
+    int64_t t = target[i];
+    if (t != ignore_index) {
       CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
       scalar_t cur_weight =
          weights != nullptr ? weights[t] : static_cast<scalar_t>(1);
@@ -400,11 +403,11 @@ __global__ void nll_loss_backward_no_reduce_cuda_kernel(
   PackedTensorAccessor64<scalar_t, 1> grad_output,
   PackedTensorAccessor64<scalar_t, 2> grad_input,
   scalar_t *weights,
-  int n_classes,
-  int ignore_index) {
+  int64_t n_classes,
+  int64_t ignore_index) {
 
   CUDA_KERNEL_LOOP(index, batch_size) {
-    int cur_target = target[index];
+    int64_t cur_target = static_cast<int64_t>(target[index]);
     if (cur_target == ignore_index) {
       continue;
     }
@@ -422,16 +425,14 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_1d(
   index_t *target,
   scalar_t *total_weight,
   bool size_average,
-  int n_classes,
+  int64_t n_classes,
   int64_t ignore_index
 ) {
-  int t = static_cast<int>(*target);
-  if (t != static_cast<int>(ignore_index)) {
+  const int64_t t = *target;
+  if (t != ignore_index) {
     CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
-    const auto grad = -(size_average ? *grad_output / *total_weight
-                                     : *grad_output);
-    grad_input[t] = weights != nullptr ? weights[t] * grad
-                                       : grad;
+    const auto grad = -(size_average ? *grad_output / *total_weight : *grad_output);
+    grad_input[t] = weights != nullptr ? weights[t] * grad : grad;
   }
 }
 
@@ -445,17 +446,19 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_2d(
   bool size_average,
   int nframe,
   int ndim,
-  int n_classes,
+  int64_t n_classes,
   int64_t ignore_index) {
   const auto grad = -(size_average ? *grad_output / *total_weight
                                    : *grad_output);
   for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
-    int t = target[i];
-    if (t != static_cast<int>(ignore_index)) {
+    const int64_t t = target[i];
+    if (t != ignore_index) {
       CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
-      grad_input[i * ndim + t] = weights != nullptr ? weights[t] * grad
-                                                    : grad;
+      // NOTE(crcrpar): this index could overflow in int64_t as `t` itself can be close to the max.
+      const uint64_t index = static_cast<uint64_t>(i) * ndim + t;
+      CUDA_KERNEL_ASSERT(index >= 0);
+      grad_input[index] = weights != nullptr ? weights[t] * grad : grad;
     }
   }
 }
 
@@ -504,8 +507,7 @@ void nll_loss_backward_out_cuda_template(
         target.data_ptr<index_t>(),
         grad_output.packed_accessor64<scalar_t, 1>(),
         grad_input.packed_accessor64<scalar_t, 2>(),
-        weight.defined() ? weight_.data_ptr<scalar_t>()
-                         : nullptr,
+        weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
         n_classes,
         ignore_index);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
diff --git a/aten/src/ATen/native/cuda/NLLLoss2d.cu b/aten/src/ATen/native/cuda/NLLLoss2d.cu
index a2027587d1c5e..d3f1284625293 100644
--- a/aten/src/ATen/native/cuda/NLLLoss2d.cu
+++ b/aten/src/ATen/native/cuda/NLLLoss2d.cu
@@ -44,6 +44,7 @@ inline scalar_t* optional_data(const Tensor& source) {
 using at::cuda::detail::CUDA_NUM_THREADS;
 using at::cuda::detail::GET_BLOCKS;
 
+// TODO(crcrpar): Think about introducing `canUse32BitIndexMath` and choose int or int64_t for `target`.
 template <typename scalar_t, typename index_t>
 C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 __global__ void nll_loss2d_forward_no_reduce_kernel(
@@ -98,11 +99,13 @@ __global__ void nll_loss2d_forward_kernel(
   for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
        i < map_nelem;
        i += step) {
-    int t = target[toffset + i];
+    int64_t t = target[toffset + i];
     if (t != ignore_index) {
       CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
       cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
-      input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
+      const auto input_index = ioffset + i + map_nelem * t;
+      CUDA_KERNEL_ASSERT(input_index >= 0);
+      input_sum -= input[input_index] * cur_weight;
       acc_weight += cur_weight;
     }
   }
@@ -185,9 +188,11 @@ __global__ void nll_loss2d_backward_kernel(
   for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
        i < map_nelem;
        i += step) {
-    int t = (int)target_thread[i];
+    const int64_t t = target_thread[i];
     if (t != ignore_index) {
       CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
+      const auto grad_input_index = i + map_nelem * t;
+      CUDA_KERNEL_ASSERT(grad_input_index >= 0);
       grad_input_thread[i + map_nelem * t] = weights != nullptr ? weights[t] * grad
                                                                 : grad;
     }
diff --git a/test/test_nn.py b/test/test_nn.py
index bfefe4f5a642e..9694c7d841ba2 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -17884,6 +17884,39 @@ def test_nll_loss_invalid_weights(self, device):
             with self.assertRaisesRegex(RuntimeError, msg):
                 F.nll_loss(x, t, weight=weight)
 
+    # Ref: https://github.com/pytorch/pytorch/issue/85005
+    @onlyCUDA
+    @largeTensorTest("45GB", "cpu")
+    @largeTensorTest("45GB", "cuda")
+    @parametrize_test("reduction", ("none", "mean", "sum"))
+    def test_nll_loss_large_tensor(self, device, reduction):
+        shape = [int(2 ** 16), int(2 ** 16) + 1]
+
+        input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
+        labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)
+
+        out = F.nll_loss(input, labels, reduction=reduction)
+
+        with torch.no_grad():
+            input_cpu = input.cpu().float().requires_grad_()
+            labels_cpu = labels.cpu()
+        out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
+        # workaround to reduce memory usage vs. self.assertEqual, see #84944
+        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
+        if reduction == "sum":
+            orig_rtol, orig_atol = rtol, atol
+            rtol, atol = 7 * rtol, 3 * atol
+        with torch.no_grad():
+            self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=rtol, atol=atol))
+        if reduction == "sum":
+            rtol, atol = orig_rtol, orig_atol
+
+        if reduction != "none":
+            out.backward()
+            out_cpu.backward()
+            with torch.no_grad():
+                self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))
+
     def _nll_loss_helper(self, input_size, reduction, expected, device):
         input = torch.rand(input_size, requires_grad=True, device=device)
         num_channels = input_size[1]
@@ -18190,6 +18223,30 @@ def check_equal(loss, inp_targ_1, inp_targ_2):
         # i.e. we don't count the ignored_idx at all.
         check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))
 
+    # Ref: https://github.com/pytorch/pytorch/issue/85005
+    @onlyCUDA
+    @largeTensorTest("45GB", "cpu")
+    @largeTensorTest("45GB", "cuda")
+    @parametrize_test("reduction", ("none", "mean", "sum"))
+    def test_cross_entropy_large_tensor(self, device, reduction):
+        logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
+        labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
+        loss = F.cross_entropy(logits, labels, reduction=reduction)
+        if reduction != "none":
+            loss.backward()
+
+        with torch.no_grad():
+            logits_cpu = logits.cpu().detach().requires_grad_()
+            labels_cpu = labels.cpu().detach()
+        loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
+        if reduction != "none":
+            loss_cpu.backward()
+
+        # workaround to reduce memory usage vs. self.assertEqual, see #84944
+        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
+        self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
+        if reduction != "none":
+            self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))
 
     def test_softshrink_negative(self, device):
         input = torch.randn(5, device=device, requires_grad=True)
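Reviewer note (not part of the patch): a minimal host-side C++ sketch, with illustrative variable names, of the overflow the widened index types guard against. With the `[2 ** 16, 2 ** 16 + 1]` shape used in the new tests, the largest flattened index of the form `i * ndim + t` no longer fits in a 32-bit `int`, which is why the kernel parameters above move from `int` to `int64_t`.

```cpp
// Illustrative only; nframe/ndim mirror the kernel argument names in the patch.
#include <cstdint>
#include <cstdio>
#include <climits>

int main() {
  // Shapes from test_nll_loss_large_tensor / test_cross_entropy_large_tensor:
  // nframe = 2**16 rows, ndim = 2**16 + 1 classes.
  const int64_t nframe = int64_t{1} << 16;
  const int64_t ndim = (int64_t{1} << 16) + 1;

  // Largest flattened index into a contiguous [nframe, ndim] tensor:
  // last row (i = nframe - 1), last class (t = ndim - 1).
  const int64_t max_index = (nframe - 1) * ndim + (ndim - 1);

  std::printf("max flattened index: %lld\n", static_cast<long long>(max_index)); // 4295032831
  std::printf("INT_MAX:             %d\n", INT_MAX);                             // 2147483647
  std::printf("fits in 32-bit int:  %s\n", max_index <= INT_MAX ? "yes" : "no"); // no
  return 0;
}
```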