
Commit 5124436

Fix const correctness for VmapPhysicalView struct methods (pytorch#41940)

Summary:
Pull Request resolved: pytorch#41940

See title. I marked methods that don't mutate the VmapPhysicalView as
`const`.
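
As a minimal sketch of the pattern (a toy struct, not the actual ATen code): a method that performs no mutation can be marked `const` so it stays callable through a `const` reference, and an accessor gets a `const` overload alongside the mutable one, mirroring the header change below.

#include <cstdint>

// Toy analog of the VmapPhysicalView change; all names here are illustrative.
struct View {
  // Non-mutating query, so it is const-qualified.
  int64_t numBatchDims() const { return num_batch_dims_; }

  // Mutable accessor plus a const overload, as tensor() gains in the .h diff.
  int64_t& data() { return data_; }
  const int64_t& data() const { return data_; }

 private:
  int64_t num_batch_dims_ = 2;
  int64_t data_ = 0;
};

// Compiles only because numBatchDims() and the const data() overload exist.
int64_t readOnly(const View& v) { return v.numBatchDims() + v.data(); }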

Test Plan: - wait for tests

Reviewed By: albanD

Differential Revision: D22764102

Pulled By: zou3519

fbshipit-source-id: 40f957ad61c85f0e5684357562a541a2712b1f38
zou3519 authored and facebook-github-bot committed Jul 28, 2020
1 parent 2bc7dae commit 5124436
Showing 2 changed files with 13 additions and 12 deletions.
aten/src/ATen/VmapTransforms.cpp (12 changes: 6 additions & 6 deletions)

@@ -60,15 +60,15 @@ MultiBatchVmapTransform::logicalToPhysical(TensorList logical_tensors) {
   TORCH_INTERNAL_ASSERT(false, "NYI");
 }
 
-int64_t VmapPhysicalView::numBatchDims() {
+int64_t VmapPhysicalView::numBatchDims() const {
   return levels_.count();
 }
 
-int64_t VmapPhysicalView::numLogicalDims() {
+int64_t VmapPhysicalView::numLogicalDims() const {
   return /*physical*/tensor_.dim() - numBatchDims();
 }
 
-VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) {
+VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
   auto logical_ndim = numLogicalDims();
   // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
   VmapDimVector result;
@@ -79,12 +79,12 @@ VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) {
   return result;
 }
 
-int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) {
+int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
   auto logical_ndim = numLogicalDims();
   return maybe_wrap_dim(logical_dim, logical_ndim) + numBatchDims();
 }
 
-VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) {
+VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
   VmapDimVector result;
   result.reserve(logical_shape.size() + numBatchDims());
   auto tensor_sizes = tensor_.sizes();
@@ -105,7 +105,7 @@ static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> lev
   return bdims;
 }
 
-Tensor VmapPhysicalView::newLogicalFromPhysical(const Tensor& physical) {
+Tensor VmapPhysicalView::newLogicalFromPhysical(const Tensor& physical) const {
   return makeBatched(physical, computeFrontBatchDimsFromLevels(levels_));
 }
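
For example, with 2 batch dims and 3 logical dims (a 5-D physical tensor), getPhysicalDim(-1) wraps -1 to logical dim 2 via maybe_wrap_dim and adds numBatchDims(), returning physical dim 4.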

aten/src/ATen/VmapTransforms.h (13 changes: 7 additions & 6 deletions)

@@ -101,6 +101,7 @@ struct TORCH_API VmapPhysicalView {
   }
 
   Tensor& tensor() { return tensor_; }
+  const Tensor& tensor() const { return tensor_; }
 
   // Maps logical dim indices to physical dim indices. Also does dim wrapping.
   //
@@ -111,23 +112,23 @@ struct TORCH_API VmapPhysicalView {
   // This is because the size of levels tell us that the first two dimensions
   // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
   // a physical dim of `n + 2`.
-  VmapDimVector getPhysicalDims(IntArrayRef logical_dims);
-  int64_t getPhysicalDim(int64_t logical_dim);
+  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
+  int64_t getPhysicalDim(int64_t logical_dim) const;
 
   // Maps a logical shape to a physical shape by pre-pending the batch
   // sizes to the logical shape.
-  VmapDimVector getPhysicalShape(IntArrayRef logical_shape);
+  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
 
   // Maps a physical tensor to a new logical tensor (BatchedTensor),
   // using the mapping info stored in this VmapPhysicalView.
   // Assumes that all of the "batch dimensions" are at the front
   // of the physical tensor.
-  Tensor newLogicalFromPhysical(const Tensor& physical);
+  Tensor newLogicalFromPhysical(const Tensor& physical) const;
 
-  int64_t numBatchDims();
+  int64_t numBatchDims() const;
 
  private:
-  int64_t numLogicalDims();
+  int64_t numLogicalDims() const;
 
   std::bitset<kVmapNumLevels> levels_;
   Tensor tensor_;
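
For context, a hedged usage sketch (a hypothetical caller, not part of this PR, and assuming the usual at:: namespace for these ATen declarations): a function that receives the view by const reference can now call the query methods above.

#include <ATen/VmapTransforms.h>

// Hypothetical helper, not part of this commit: the view is held by const
// reference, so every method called on it must be const-qualified. Before
// this change the call below would not compile.
at::VmapDimVector physicalShapeOf(const at::VmapPhysicalView& view,
                                  at::IntArrayRef logical_shape) {
  // getPhysicalShape() pre-pends the batch sizes to the logical shape.
  return view.getPhysicalShape(logical_shape);
}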
