Unify isViewable, handle n-dimensional empty tensors. (pytorch#8883)
* Unify isViewable, handle n-dimensional empty tensors.

1) Unify the two isViewable functions in ATen and TH.
2) Handle n-dimensional empty tensors in the implementation.
3) Clarify some comments.

This requires an extra copy in the TH case, but that will go away.

* Also unify THCTensor version.

* Remove C-linkage from THTensor_compute_stride.

* Update comment.
gchanan committed Jun 26, 2018
1 parent 6e28d4d commit 31327dd
Showing 5 changed files with 86 additions and 132 deletions.
44 changes: 2 additions & 42 deletions aten/src/ATen/native/TensorShape.cpp
@@ -4,6 +4,7 @@
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/optional.h"
#include <TH/THTensor.hpp>

#include <algorithm>

@@ -224,53 +225,12 @@ static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
throw std::runtime_error(ss.str());
}

static at::optional<std::vector<int64_t>>
compute_stride(const Tensor& self, IntList newshape) {
auto oldstride = self.strides();
auto oldshape = self.sizes();
if (oldshape.empty()) {
return std::vector<int64_t>(newshape.size(), 1);
}

std::vector<int64_t> newstride(newshape.size());
int64_t view_d = newshape.size() - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = oldstride.back();
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
tensor_numel *= oldshape[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(oldshape[tensor_d - 1] != 1 && oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || newshape[view_d] == 1)) {
newstride[view_d] = view_numel * chunk_base_stride;
view_numel *= newshape[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return {};
}
if (tensor_d > 0) {
chunk_base_stride = oldstride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
if (view_d != -1) {
return {};
}
return newstride;
}

Tensor reshape(const Tensor& self, IntList proposed_shape) {
if (self.type().is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
}
auto shape = infer_size(proposed_shape, self.numel());
if (auto stride = compute_stride(self, shape)) {
if (auto stride = THTensor_compute_stride(self.sizes(), self.strides(), shape)) {
return self.as_strided(shape, *stride);
}
return at::_unsafe_view(self.clone(), shape);
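
For readers of this diff, here is a minimal, self-contained sketch (not part of the commit; names are illustrative only) of the ``contiguous chunk'' condition the unified stride computation relies on, and of why reshape on a transposed tensor falls back to clone() + _unsafe_view():

#include <cstdint>
#include <cstdio>
#include <vector>

// Dimensions i and i+1 belong to the same chunk iff stride[i] == shape[i+1] * stride[i+1].
static bool can_merge(const std::vector<int64_t>& shape,
                      const std::vector<int64_t>& stride, size_t i) {
  return stride[i] == shape[i + 1] * stride[i + 1];
}

int main() {
  // Contiguous 2x3x4 tensor (strides {12, 4, 1}): every adjacent pair merges,
  // so any reshape of its 24 elements can be expressed as a view.
  std::vector<int64_t> shape{2, 3, 4}, stride{12, 4, 1};
  std::printf("contiguous 0,1: %d  1,2: %d\n",
              (int)can_merge(shape, stride, 0), (int)can_merge(shape, stride, 1));  // 1 1

  // The same storage seen through a transposed view (shape {3, 2, 4}, strides {4, 12, 1}):
  // dimensions 0 and 1 fall into separate chunks, so e.g. a reshape to {6, 4}
  // cannot be a view, and reshape() above falls back to clone() + _unsafe_view().
  std::vector<int64_t> tshape{3, 2, 4}, tstride{4, 12, 1};
  std::printf("transposed 0,1: %d\n", (int)can_merge(tshape, tstride, 0));  // 0
  return 0;
}
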
68 changes: 68 additions & 0 deletions aten/src/TH/THTensor.cpp
@@ -36,6 +36,8 @@
#include "generic/THTensorLapack.cpp"
#include "THGenerateFloatTypes.h"

#include <numeric>

void THTensor_free(THTensor *self)
{
if(!self)
@@ -54,3 +56,69 @@ void THTensor_free(THTensor *self)
}
}
}

// On a high level,
// 1. separate oldshape into chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] * oldstride[i+1]
// 2. newshape must be able to be separated into the same number of chunks as oldshape was separated into,
// where each chunk of newshape has matching ``numel'', i.e., number of subspaces,
// as the corresponding chunk of oldshape.
at::optional<std::vector<int64_t>>
THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride, at::IntList newshape) {
if (oldshape.empty()) {
return std::vector<int64_t>(newshape.size(), 1);
}

// NOTE: stride is somewhat arbitrary in the numel() == 0 case;
// to match NumPy behavior we copy the strides if the size matches, otherwise
// we use the stride as if it were computed via resize.
// This could perhaps be combined with the below code, but the complexity didn't seem worth it.
int64_t numel = std::accumulate(oldshape.begin(), oldshape.end(), 1, std::multiplies<int64_t>());
if (numel == 0 && oldshape.equals(newshape)) {
return std::vector<int64_t>(oldstride);
}

std::vector<int64_t> newstride(newshape.size());
if (numel == 0) {
int64_t view_numel = 1;
for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
if (view_d == newshape.size() - 1) {
newstride[view_d] = 1;
} else {
newstride[view_d] = std::max<int64_t>(newshape[view_d+1], 1) * newstride[view_d+1];
}
}
return newstride;
}

int64_t view_d = newshape.size() - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = oldstride.back();
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
tensor_numel *= oldshape[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(oldshape[tensor_d - 1] != 1 && oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || newshape[view_d] == 1)) {
newstride[view_d] = view_numel * chunk_base_stride;
view_numel *= newshape[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return at::nullopt;
}
if (tensor_d > 0) {
chunk_base_stride = oldstride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
if (view_d != -1) {
return at::nullopt;
}
return newstride;
}
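
To make the numel() == 0 branch above concrete, here is a small stand-alone sketch (not ATen code) reproducing its stride rule: when an empty tensor's shape changes, strides are filled in as if the tensor had been freshly resized, with each size clamped to at least 1 to match NumPy behavior:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> empty_tensor_strides(const std::vector<int64_t>& newshape) {
  std::vector<int64_t> newstride(newshape.size());
  for (int64_t d = (int64_t)newshape.size() - 1; d >= 0; d--) {
    if (d == (int64_t)newshape.size() - 1) {
      newstride[d] = 1;
    } else {
      newstride[d] = std::max<int64_t>(newshape[d + 1], 1) * newstride[d + 1];
    }
  }
  return newstride;
}

int main() {
  // {0, 2, 3} -> strides {6, 3, 1}; the zero-sized dimension contributes as if it were size 1.
  for (int64_t s : empty_tensor_strides({0, 2, 3})) std::printf("%lld ", (long long)s);
  std::printf("\n");
  // {2, 0, 3} -> strides {3, 3, 1}.
  for (int64_t s : empty_tensor_strides({2, 0, 3})) std::printf("%lld ", (long long)s);
  std::printf("\n");
  return 0;
}
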
2 changes: 2 additions & 0 deletions aten/src/TH/THTensor.hpp
@@ -63,3 +63,5 @@ typedef struct THTensor
#include "THGenerateAllTypes.h"

TH_API void THTensor_free(THTensor *self);
at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
at::IntList newshape);
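
A hedged sketch of how a caller might consume this new declaration, assuming the TH/ATen headers are on the include path (the helper name try_view_strides is illustrative, not part of the commit):

#include <TH/THTensor.hpp>
#include <vector>

// Returns true and fills `out` when `new_sizes` is viewable over the given
// sizes/strides; returns false when some new dimension would span two chunks.
bool try_view_strides(at::IntList sizes, at::IntList strides,
                      at::IntList new_sizes, std::vector<int64_t>* out) {
  auto stride = THTensor_compute_stride(sizes, strides, new_sizes);
  if (!stride.has_value()) {
    return false;
  }
  *out = *stride;
  return true;
}
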
52 changes: 7 additions & 45 deletions aten/src/TH/generic/THTensor.cpp
@@ -243,58 +243,20 @@ THTensor *THTensor_(newUnfold)(THTensor *tensor, int dimension_, int64_t size_,
return self;
}

// Also sets new_stride if viewable.
//
// On a high level,
// 1. separate tensor->size into chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., stride[i] = size[i+1] * stride[i+1]
// 2. view_size must be able to be separated into same number of chunks, where
// each chunk pair has matching ``numel'', i.e., number of subspaces.
static int THTensor_(isViewable)(THTensor *tensor, THLongStorage *view_size, THLongStorage *new_stride) {
// dim indices
int64_t tensor_d = tensor->_dim() - 1;
if (tensor_d < 0) {
return 1;
}
int64_t view_d = view_size->size - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = tensor->stride[tensor_d];
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (; tensor_d >= 0; tensor_d--) {
tensor_numel *= tensor->size[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(tensor->size[tensor_d - 1] != 1 && tensor->stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || THLongStorage_data(view_size)[view_d] == 1)) {
THLongStorage_data(new_stride)[view_d] = view_numel * chunk_base_stride;
view_numel *= THLongStorage_data(view_size)[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return 0;
}
if (tensor_d > 0) {
chunk_base_stride = tensor->stride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
// check that we iterated through all view size
return view_d == -1;
}

THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size)
{
ptrdiff_t numel = THTensor_(nElement)(tensor);
THTensor *self = THTensor_(new)();
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
THLongStorage *new_stride = THLongStorage_newWithSize(size->size);
THArgCheck(THTensor_(isViewable)(tensor, inferred_size, new_stride), 2, "view size is "
auto stride = THTensor_compute_stride(at::IntList(tensor->size, tensor->dim()),
at::IntList(tensor->stride, tensor->dim()),
at::IntList(inferred_size->data<int64_t>(), inferred_size->size));
THArgCheck(stride.has_value(), 2, "view size is "
"not compatible with input tensor's size and stride (at least one dimension spans "
"across two contiguous subspaces). Call .contiguous() before .view().");
auto stride_value = *stride;
THLongStorage *new_stride = THLongStorage_newWithSize(stride_value.size());
THLongStorage_rawCopy(new_stride, stride_value.data());
THTensor_(setStorage)(self, tensor->storage, tensor->storageOffset, inferred_size, new_stride);
THLongStorage_free(inferred_size);
THLongStorage_free(new_stride);
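
As a sanity check on the strides newView() now installs, a brute-force, self-contained verification (illustrative, not part of the commit) that the {6, 4} view of a contiguous {2, 3, 4} tensor, with strides {4, 1} as THTensor_compute_stride would return here, addresses exactly the same storage offsets as the original indexing:

#include <cassert>
#include <cstdint>
#include <vector>

static int64_t offset(const std::vector<int64_t>& idx, const std::vector<int64_t>& stride) {
  int64_t off = 0;
  for (size_t d = 0; d < idx.size(); d++) off += idx[d] * stride[d];
  return off;
}

int main() {
  const std::vector<int64_t> old_stride{12, 4, 1};  // contiguous {2, 3, 4}
  const std::vector<int64_t> new_stride{4, 1};      // view as {6, 4}
  for (int64_t i = 0; i < 2; i++)
    for (int64_t j = 0; j < 3; j++)
      for (int64_t k = 0; k < 4; k++) {
        int64_t flat = (i * 3 + j) * 4 + k;  // row-major element number
        assert(offset({i, j, k}, old_stride) == offset({flat / 4, flat % 4}, new_stride));
      }
  return 0;  // all 24 elements line up, so the view can safely share storage
}
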
52 changes: 7 additions & 45 deletions aten/src/THC/generic/THCTensor.cpp
@@ -232,58 +232,20 @@ THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimensi
return self;
}

// Also sets new_stride if viewable.
//
// On a high level,
// 1. separate tensor->size into chunks of dimensions, where the dimensions are
// ``contiguous'' in each chunk, i.e., stride[i] = size[i+1] * stride[i+1]
// 2. view_size must be able to be separated into same number of chunks, where
// each chunk pair has matching ``numel'', i.e., number of subspaces.
static int THCTensor_(isViewable)(THCState *state, THCTensor *tensor, THLongStorage *view_size, THLongStorage *new_stride) {
// dim indices
int64_t tensor_d = tensor->_dim() - 1;
if (tensor_d < 0) {
return 1;
}
int64_t view_d = view_size->size - 1;
// stride for each subspace in the chunk
int64_t chunk_base_stride = tensor->stride[tensor_d];
// numel in current chunk
int64_t tensor_numel = 1;
int64_t view_numel = 1;
for (; tensor_d >= 0; tensor_d--) {
tensor_numel *= tensor->size[tensor_d];
// if end of tensor size chunk, check view
if ((tensor_d == 0) ||
(tensor->size[tensor_d - 1] != 1 && tensor->stride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
while (view_d >= 0 && (view_numel < tensor_numel || THLongStorage_data(view_size)[view_d] == 1)) {
THLongStorage_data(new_stride)[view_d] = view_numel * chunk_base_stride;
view_numel *= THLongStorage_data(view_size)[view_d];
view_d--;
}
if (view_numel != tensor_numel) {
return 0;
}
if (tensor_d > 0) {
chunk_base_stride = tensor->stride[tensor_d - 1];
tensor_numel = 1;
view_numel = 1;
}
}
}
// check that we iterated through all view size
return view_d == -1;
}

THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size)
{
ptrdiff_t numel = THCTensor_(nElement)(state, tensor);
THCTensor *self = THCTensor_(new)(state);
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
THLongStorage *new_stride = THLongStorage_newWithSize(size->size);
THArgCheck(THCTensor_(isViewable)(state, tensor, inferred_size, new_stride), 2, "View size is "
auto stride = THTensor_compute_stride(at::IntList(tensor->size, tensor->dim()),
at::IntList(tensor->stride, tensor->dim()),
at::IntList(inferred_size->data<int64_t>(), inferred_size->size));
THArgCheck(stride.has_value(), 2, "view size is "
"not compatible with input tensor's size and stride (at least one dimension spans "
"across two contiguous subspaces). Call .contiguous() before .view().");
auto stride_value = *stride;
THLongStorage *new_stride = THLongStorage_newWithSize(stride_value.size());
THLongStorage_rawCopy(new_stride, stride_value.data());
THCTensor_(setStorage)(state, self, tensor->storage, tensor->storageOffset, inferred_size, new_stride);
THLongStorage_free(inferred_size);
THLongStorage_free(new_stride);
