Skip to content

Commit

Permalink
Make at::Tensor::to() const (pytorch#8839)
Browse files Browse the repository at this point in the history
* Make at::Tensor::to() const

* Add cheaper checks to Tensor::to
  • Loading branch information
goldsborough committed Jun 26, 2018
1 parent 5cb8586 commit 7a61479
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
16 changes: 13 additions & 3 deletions aten/src/ATen/TensorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,25 @@ inline Tensor to(
}
} // namespace detail

inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking) {
inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking)
const {
if (this->device() == device && this->dtype() == dtype) {
return *this;
}
return detail::to(*this, options().device(device).dtype(dtype), non_blocking);
}

inline Tensor Tensor::to(ScalarType dtype, bool non_blocking) {
inline Tensor Tensor::to(ScalarType dtype, bool non_blocking) const {
if (this->dtype() == dtype) {
return *this;
}
return detail::to(*this, options().dtype(dtype), non_blocking);
}

inline Tensor Tensor::to(Device device, bool non_blocking) {
inline Tensor Tensor::to(Device device, bool non_blocking) const {
if (this->device() == device) {
return *this;
}
return detail::to(*this, options().device(device), non_blocking);
}
} // namespace at
7 changes: 4 additions & 3 deletions aten/src/ATen/templates/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ struct Tensor : public detail::TensorBase {
inline Tensor toBackend(Backend b) const;

/// New-style `to()` methods.
Tensor to(Device device, ScalarType dtype, bool non_blocking = false);
Tensor to(ScalarType dtype, bool non_blocking = false);
Tensor to(Device device, bool non_blocking = false);
/// NB: These methods are defined in TensorOptions.h.
Tensor to(Device device, ScalarType dtype, bool non_blocking = false) const;
Tensor to(ScalarType dtype, bool non_blocking = false) const;
Tensor to(Device device, bool non_blocking = false) const;

/// Returns true if the `Tensor` is actually a `torch::autograd::Variable`.
/// Defined in Type.h because of include order issues.
Expand Down

0 comments on commit 7a61479

Please sign in to comment.