Use int64_t for nll_loss with cuda inputs (pytorch#85395)
crcrpar authored and pytorchmergebot committed Sep 29, 2022
1 parent 5f26df0 commit ef0baba
Showing 3 changed files with 94 additions and 30 deletions.
56 changes: 29 additions & 27 deletions aten/src/ATen/native/cuda/Loss.cu
@@ -152,6 +152,9 @@ namespace {

constexpr int NLL_LOSS_THREADS = 32;

// TODO(crcrpar): Think about removing this dispatch, and introducing canUse32BitIndexMath
// NOTE(crcrpar): ATen/native/cuda/NLLLoss2d.cu's nll loss implementation doesn't have the following dispatch for `target`; it hardcodes int64_t instead.
// With this dispatch, `target` could be Byte while `ignore_index` is int64_t, which doesn't seem quite reasonable.
#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Byte, index_t, __VA_ARGS__) \
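The dispatch above instantiates these kernels for Byte or Long `target` tensors, and the TODO suggests replacing it with a `canUse32BitIndexMath`-style decision. A rough, numel-only sketch of that decision follows; the real ATen helper is more thorough, so treat this only as an illustration of the idea:

import torch

def fits_32bit_index_math(shape) -> bool:
    # The largest flattened offset must fit in a signed 32-bit int for
    # 32-bit index arithmetic to be safe.
    numel = 1
    for s in shape:
        numel *= s
    return numel - 1 <= torch.iinfo(torch.int32).max

print(fits_32bit_index_math((128, 10)))               # True: plain int indexing would do
print(fits_32bit_index_math((2 ** 16, 2 ** 16 + 1)))  # False: the shape from the new tests needs 64-bit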
@@ -164,8 +167,8 @@ __global__ void nll_loss_forward_no_reduce_cuda_kernel(
index_t* target,
scalar_t* output,
scalar_t* weights,
int n_classes,
int ignore_index) {
int64_t n_classes,
int64_t ignore_index) {
CUDA_KERNEL_LOOP(index, batch_size) {
int cur_target = target[index];
if (cur_target == ignore_index) {
@@ -187,12 +190,12 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_1d(
index_t* target,
scalar_t* weights,
bool size_average,
int n_classes,
int64_t n_classes,
int64_t ignore_index) {
CUDA_KERNEL_ASSERT(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0);

int t = static_cast<int>(*target);
if (t != static_cast<int>(ignore_index)) {
int64_t t = static_cast<int64_t>(*target);
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
const auto cur_weight = weights != nullptr ? weights[t] : scalar_t{1};
*total_weight = cur_weight;
Expand Down Expand Up @@ -223,9 +226,9 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
index_t* target,
scalar_t* weights,
bool size_average,
int nframe,
int ndim,
int n_classes,
int64_t nframe,
int64_t ndim,
int64_t n_classes,
int64_t ignore_index) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
__shared__ accscalar_t sh_inputs[NLL_LOSS_THREADS],
@@ -234,8 +237,8 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
sh_inputs[threadIdx.x] = static_cast<accscalar_t>(0);
acc_weight[threadIdx.x] = static_cast<accscalar_t>(0);
for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
int t = target[i];
if (t != static_cast<int>(ignore_index)) {
int64_t t = target[i];
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
scalar_t cur_weight =
weights != nullptr ? weights[t] : static_cast<scalar_t>(1);
@@ -400,11 +403,11 @@ __global__ void nll_loss_backward_no_reduce_cuda_kernel(
PackedTensorAccessor64<scalar_t, 1> grad_output,
PackedTensorAccessor64<scalar_t, 2> grad_input,
scalar_t *weights,
int n_classes,
int ignore_index) {
int64_t n_classes,
int64_t ignore_index) {
CUDA_KERNEL_LOOP(index, batch_size) {
int cur_target = target[index];
int64_t cur_target = static_cast<int64_t>(target[index]);
if (cur_target == ignore_index) {
continue;
}
@@ -422,16 +425,14 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_1d(
index_t *target,
scalar_t *total_weight,
bool size_average,
int n_classes,
int64_t n_classes,
int64_t ignore_index
) {
int t = static_cast<int>(*target);
if (t != static_cast<int>(ignore_index)) {
const int64_t t = *target;
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
const auto grad = -(size_average ? *grad_output / *total_weight
: *grad_output);
grad_input[t] = weights != nullptr ? weights[t] * grad
: grad;
const auto grad = -(size_average ? *grad_output / *total_weight : *grad_output);
grad_input[t] = weights != nullptr ? weights[t] * grad : grad;
}
}
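For reference, the value the reduce kernels write is -(grad_output / total_weight) with size averaging (or just -grad_output without it), scaled by the class weight and scattered into the target slot. A small CPU check against autograd, using a toy shape chosen purely for illustration and the default ignore_index, matches that formula:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, requires_grad=True)
t = torch.tensor([0, 2, 1, 2])
w = torch.rand(3)

# Scalar loss, so grad_output is 1 in the backward pass.
F.nll_loss(x, t, weight=w, reduction="mean").backward()

# Hand-built gradient mirroring the kernel: grad = -(grad_output / total_weight),
# written into column t of each row and scaled by weights[t].
total_weight = w[t].sum()
expected = torch.zeros_like(x)
expected[torch.arange(4), t] = -w[t] / total_weight
print(torch.allclose(x.grad, expected))  # True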
@@ -445,17 +446,19 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_2d(
bool size_average,
int nframe,
int ndim,
int n_classes,
int64_t n_classes,
int64_t ignore_index) {
const auto grad = -(size_average ? *grad_output / *total_weight
: *grad_output);
for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
int t = target[i];
if (t != static_cast<int>(ignore_index)) {
const int64_t t = target[i];
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
grad_input[i * ndim + t] = weights != nullptr ? weights[t] * grad
: grad;
// NOTE(crcrpar): this index could overflow in int64_t as `t` itself can be close to the max.
const uint64_t index = static_cast<uint64_t>(i) * ndim + t;
CUDA_KERNEL_ASSERT(index >= 0);
grad_input[index] = weights != nullptr ? weights[t] * grad : grad;
}
}
}
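With the sizes exercised by the new tests (nframe = 2**16 rows, ndim = n_classes = 2**16 + 1), the flattened offset i * ndim + t no longer fits in 32 bits, which is what the widened parameter types and the NOTE above guard against. A quick check of the arithmetic:

nframe, ndim = 2 ** 16, 2 ** 16 + 1   # sizes matching the tests added below
i, t = nframe - 1, ndim - 1           # last row, last class
flat = i * ndim + t

print(flat)                # 4295032831
print(flat > 2 ** 31 - 1)  # True: too large for a signed 32-bit index
print(flat % 2 ** 32)      # 65535: what a wrapped 32-bit computation would silently produce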
@@ -504,8 +507,7 @@ void nll_loss_backward_out_cuda_template(
target.data_ptr<index_t>(),
grad_output.packed_accessor64<scalar_t, 1>(),
grad_input.packed_accessor64<scalar_t, 2>(),
weight.defined() ? weight_.data_ptr<scalar_t>()
: nullptr,
weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
n_classes,
ignore_index);
C10_CUDA_KERNEL_LAUNCH_CHECK();
11 changes: 8 additions & 3 deletions aten/src/ATen/native/cuda/NLLLoss2d.cu
@@ -44,6 +44,7 @@ inline scalar_t* optional_data(const Tensor& source) {
using at::cuda::detail::CUDA_NUM_THREADS;
using at::cuda::detail::GET_BLOCKS;

// TODO(crcrpar): Think about introducing `canUse32BitIndexMath` and choose int or int64_t for `target`.
template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void nll_loss2d_forward_no_reduce_kernel(
@@ -98,11 +99,13 @@ __global__ void nll_loss2d_forward_kernel(
for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem;
i += step) {
int t = target[toffset + i];
int64_t t = target[toffset + i];
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
cur_weight = weight != nullptr ? weight[t] : static_cast<scalar_t>(1);
input_sum -= input[ioffset + i + map_nelem * t] * cur_weight;
const auto input_index = ioffset + i + map_nelem * t;
CUDA_KERNEL_ASSERT(input_index >= 0);
input_sum -= input[input_index] * cur_weight;
acc_weight += cur_weight;
}
}
@@ -185,9 +188,11 @@ __global__ void nll_loss2d_backward_kernel(
for (int i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem;
i += step) {
int t = (int)target_thread[i];
const int64_t t = target_thread[i];
if (t != ignore_index) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
const auto grad_input_index = i + map_nelem * t;
CUDA_KERNEL_ASSERT(grad_input_index >= 0);
grad_input_thread[i + map_nelem * t] = weights != nullptr ? weights[t] * grad
: grad;
}
57 changes: 57 additions & 0 deletions test/test_nn.py
@@ -17884,6 +17884,39 @@ def test_nll_loss_invalid_weights(self, device):
with self.assertRaisesRegex(RuntimeError, msg):
F.nll_loss(x, t, weight=weight)

    # Ref: https://github.com/pytorch/pytorch/issues/85005
    @onlyCUDA
    @largeTensorTest("45GB", "cpu")
    @largeTensorTest("45GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_nll_loss_large_tensor(self, device, reduction):
        shape = [int(2 ** 16), int(2 ** 16) + 1]

        input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
        labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)

        out = F.nll_loss(input, labels, reduction=reduction)

        with torch.no_grad():
            input_cpu = input.cpu().float().requires_grad_()
            labels_cpu = labels.cpu()
        out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
        # workaround to reduce memory usage vs. self.assertEqual, see #84944
        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
        if reduction == "sum":
            orig_rtol, orig_atol = rtol, atol
            rtol, atol = 7 * rtol, 3 * atol
        with torch.no_grad():
            self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=rtol, atol=atol))
        if reduction == "sum":
            rtol, atol = orig_rtol, orig_atol

        if reduction != "none":
            out.backward()
            out_cpu.backward()
            with torch.no_grad():
                self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))
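The get_tolerances + torch.allclose pattern above replaces self.assertEqual because the latter uses considerably more memory at this tensor size (see #84944). A minimal sketch of the same pattern on small tensors, reusing the exact call from the test:

import torch

rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
a = torch.randn(8)
b = a + 1e-7 * torch.randn(8)   # perturbation well within float32 tolerances
print(torch.allclose(a, b, rtol=rtol, atol=atol))  # True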

def _nll_loss_helper(self, input_size, reduction, expected, device):
input = torch.rand(input_size, requires_grad=True, device=device)
num_channels = input_size[1]
@@ -18190,6 +18223,30 @@ def check_equal(loss, inp_targ_1, inp_targ_2):
# i.e. we don't count the ignored_idx at all.
check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))

    # Ref: https://github.com/pytorch/pytorch/issues/85005
    @onlyCUDA
    @largeTensorTest("45GB", "cpu")
    @largeTensorTest("45GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_cross_entropy_large_tensor(self, device, reduction):
        logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
        labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
        loss = F.cross_entropy(logits, labels, reduction=reduction)
        if reduction != "none":
            loss.backward()

        with torch.no_grad():
            logits_cpu = logits.cpu().detach().requires_grad_()
            labels_cpu = labels.cpu().detach()
        loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
        if reduction != "none":
            loss_cpu.backward()

        # workaround to reduce memory usage vs. self.assertEqual, see #84944
        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
        self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
        if reduction != "none":
            self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))
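This test reaches the same nll_loss CUDA kernels because F.cross_entropy is log_softmax followed by nll_loss; a quick sanity check of that equivalence on a small input:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)
labels = torch.randint(5, (8,))

ce = F.cross_entropy(logits, labels, reduction="mean")
nll = F.nll_loss(F.log_softmax(logits, dim=1), labels, reduction="mean")
print(torch.allclose(ce, nll))  # True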

def test_softshrink_negative(self, device):
input = torch.randn(5, device=device, requires_grad=True)
