Skip to content

Commit

Permalink
Fix nonzero and tensor printing of n-dimensional empty tensors. (pyto…
Browse files Browse the repository at this point in the history
  • Loading branch information
gchanan authored and ezyang committed Jun 25, 2018
1 parent 1e7fcb5 commit 04440d2
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
23 changes: 12 additions & 11 deletions aten/src/ATen/Formatting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,7 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
stream << defaultfloat << tensor.data<double>()[0] << std::endl;
stream << "[ " << tensor_.pImpl->toString() << "{} ]";
} else if(tensor.ndimension() == 1) {
if (tensor.numel() == 0) {
stream << "[ Tensor (empty) ]";
}
else {
if (tensor.numel() > 0) {
double scale;
int64_t sz;
std::tie(scale, sz) = __printFormat(stream, tensor);
Expand All @@ -274,18 +271,22 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
for(int64_t i = 0; i < tensor.size(0); i++) {
stream << std::setw(sz) << tensor_p[i]/scale << std::endl;
}
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "} ]";
}
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "} ]";
} else if(tensor.ndimension() == 2) {
__printMatrix(stream, tensor, linesize, 0);
if (tensor.numel() > 0) {
__printMatrix(stream, tensor, linesize, 0);
}
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "," << tensor.size(1) << "} ]";
} else {
if (tensor.numel() > 0) {
__printTensor(stream, tensor, linesize);
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0);
for(int64_t i = 1; i < tensor.ndimension(); i++) {
stream << "," << tensor.size(i);
}
stream << "} ]";
}
stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0);
for(int64_t i = 1; i < tensor.ndimension(); i++) {
stream << "," << tensor.size(i);
}
stream << "} ]";
}
}
return stream;
Expand Down
6 changes: 3 additions & 3 deletions aten/src/TH/generic/THTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,20 @@ void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
#ifdef DEBUG
THAssert(numel <= LONG_MAX);
#endif
THLongTensor_resize2d(subscript, numel, tensor->_dim());
THLongTensor_resize2d(subscript, numel, tensor->dim());

/* Second pass populates subscripts */
subscript_data = THLongTensor_data(subscript);
TH_TENSOR_APPLY(real, tensor,
if IS_NONZERO(*tensor_data) {
div = 1;

for (dim = tensor->_dim() - 1; dim >= 0; dim--) {
for (dim = tensor->dim() - 1; dim >= 0; dim--) {
*(subscript_data + dim) = (i/div) % tensor->size[dim];
div *= tensor->size[dim];
}

subscript_data += tensor->_dim();
subscript_data += tensor->dim();
}
++i;);
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/THC/generic/THCTensorMath.cu
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
self = THCTensor_(newContiguous)(state, self);
thrust::device_ptr<real> self_data(THCTensor_(data)(state, self));

int num_dim = THCTensor_(_nDimension)(state, self);
int num_dim = THCTensor_(nDimension)(state, self);
int64_t N = THCTensor_(nElement)(state, self);

THCudaLongTensor_resize2d(state, tensor, N, num_dim);
Expand Down
11 changes: 11 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6208,6 +6208,17 @@ def test_nonzero(self):
for i in range(dst1.size(0)):
self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1], dst1[i, 2]].item(), 0)

def test_nonzero_empty(self):
if not torch._C._use_zero_size_dim():
return

devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for device in devices:
x = torch.randn(0, 2, 0, 5, 0, device=device)
y = torch.nonzero(x)
self.assertEqual(0, y.numel())
self.assertEqual(torch.Size([0, 5]), y.shape)

def test_deepcopy(self):
from copy import deepcopy
a = torch.randn(5, 5)
Expand Down

0 comments on commit 04440d2

Please sign in to comment.