torch.fft: Multi-dimensional transforms (pytorch#44550)
Summary:
Pull Request resolved: pytorch#44550

Part of the `torch.fft` work (gh-42175).
This adds the n-dimensional transforms: `fftn`, `ifftn`, `rfftn`, and `irfftn`.

This aims for correctness first, with the implementation built on top of the existing `_fft_with_size` restrictions. I plan to follow up later with a more efficient rewrite that makes `_fft_with_size` work with an arbitrary number of dimensions.
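
For orientation, a minimal usage sketch of the Python API added here (a sketch only; semantics are intended to match NumPy's numpy.fft):

    import torch

    x = torch.randn(4, 5, 6, dtype=torch.complex64)
    X = torch.fft.fftn(x)                                 # transform every dimension
    assert (torch.fft.ifftn(X) - x).abs().max() < 1e-5    # ifftn inverts fftn

    y = torch.randn(4, 5, 6)                              # real input
    Y = torch.fft.rfftn(y)                                # one-sided last dim: (4, 5, 4)
    y2 = torch.fft.irfftn(Y, s=y.shape)                   # pass s to recover the exact length
    assert (y2 - y).abs().max() < 1e-5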

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D23846032

Pulled By: mruberry

fbshipit-source-id: e6950aa8be438ec5cb95fb10bd7b8bc9ffb7d824
peterbell10 authored and facebook-github-bot committed Sep 24, 2020
1 parent 070fe15 commit 6a2e9eb
Showing 7 changed files with 662 additions and 7 deletions.
14 changes: 11 additions & 3 deletions aten/src/ATen/WrapDimUtils.h
@@ -30,14 +30,15 @@ static inline int64_t maybe_wrap_dim(int64_t dim, const std::vector<std::vector<
   return maybe_wrap_dim(dim, tensor_sizes[0].size());
 }
 
-// wrap each of dims basing on dim_post_expr
-static inline void maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_expr) {
+// wrap each dim in the dims array, taking dim_post_expr as the true number of dimensions
+static inline void maybe_wrap_dims_n(int64_t* dims, int64_t ndims, int64_t dim_post_expr) {
   if (dim_post_expr <= 0) {
     dim_post_expr = 1; // this will make range [-1, 0]
   }
   int64_t min = -dim_post_expr;
   int64_t max = dim_post_expr - 1;
-  for (auto& dim : dims) {
+  for (int64_t i = 0; i < ndims; ++i) {
+    auto& dim = dims[i];
     if (dim < min || dim > max) {
       TORCH_CHECK_INDEX(false,
         "Dimension out of range (expected to be in range of [",
@@ -47,6 +48,13 @@
   }
 }
 
+// Wrap each dim in a contiguous container, taking dim_post_expr as the true number of dimensions
+// E.g. could also be std::array or c10::SmallVector
+template <typename Container>
+inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
+  return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
+}
+
 // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
 // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
 // to be "skipped" (both for wrap dimension behavior and dimension size checking).
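
For intuition, the wrapping rule is the usual PyTorch dim normalization: negative dims count back from the end, and anything outside [-ndim, ndim-1] throws. A rough Python mirror of the helper, for illustration only (the actual wrap of negative dims happens in the elided body above):

    def maybe_wrap_dims(dims, ndim):
        ndim = max(ndim, 1)              # degenerate case: accept dims in [-1, 0]
        lo, hi = -ndim, ndim - 1
        out = []
        for d in dims:
            if d < lo or d > hi:
                raise IndexError(f"Dimension out of range (expected to be in "
                                 f"range of [{lo}, {hi}], but got {d})")
            out.append(d + ndim if d < 0 else d)
        return out

    assert maybe_wrap_dims([-1, 0], 3) == [2, 0]
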
168 changes: 168 additions & 0 deletions aten/src/ATen/native/SpectralOps.cpp
@@ -203,6 +203,116 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
  return out;
}

// Dimensions to transform, and the signal shape in those dimensions
struct ShapeAndDims {
  DimVector shape, dim;
};

// Pre-process n-dimensional fft's `s` and `dim` arguments.
// Wraps dimensions and applies defaulting behavior.
// Also checks transform dims are unique and transform shape is non-empty.
ShapeAndDims canonicalize_fft_shape_and_dim_args(
    Tensor input, c10::optional<IntArrayRef> shape, c10::optional<IntArrayRef> dim) {
  const int64_t input_dim = input.dim();
  const IntArrayRef input_sizes = input.sizes();
  ShapeAndDims ret;

  if (dim) {
    ret.dim.resize(dim->size());
    std::copy(dim->begin(), dim->end(), ret.dim.begin());
    maybe_wrap_dims(ret.dim, input_dim);

    // Check dims are unique
    DimVector copy = ret.dim;
    std::sort(copy.begin(), copy.end());
    auto duplicate = std::adjacent_find(copy.begin(), copy.end());
    TORCH_CHECK(duplicate == copy.end(), "FFT dims must be unique");
  }

  if (shape) {
    // Has shape, may have dim
    TORCH_CHECK(!dim || dim->size() == shape->size(),
                "When given, dim and shape arguments must have the same length");
    TORCH_CHECK(shape->size() <= input_dim,
                "Got shape with ", shape->size(), " values but input tensor "
                "only has ", input_dim, " dimensions.");
    const int64_t transform_ndim = shape->size();
    // If shape is given, dims defaults to the last shape.size() dimensions
    if (!dim) {
      ret.dim.resize(transform_ndim);
      std::iota(ret.dim.begin(), ret.dim.end(), input_dim - transform_ndim);
    }

    // Translate shape of -1 to the default length
    ret.shape.resize(transform_ndim);
    for (int64_t i = 0; i < transform_ndim; ++i) {
      const auto n = (*shape)[i];
      ret.shape[i] = n == -1 ? input_sizes[ret.dim[i]] : n;
    }
  } else if (!dim) {
    // No shape, no dim
    ret.dim.resize(input_dim);
    std::iota(ret.dim.begin(), ret.dim.end(), int64_t{0});
    ret.shape.resize(input_dim);
    std::copy(input_sizes.begin(), input_sizes.end(), ret.shape.begin());
  } else {
    // No shape, has dim
    ret.shape.resize(ret.dim.size());
    for (int64_t i = 0; i < ret.dim.size(); ++i) {
      ret.shape[i] = input_sizes[ret.dim[i]];
    }
  }

  for (int64_t i = 0; i < ret.shape.size(); ++i) {
    TORCH_CHECK(ret.shape[i] > 0,
                "Invalid number of data points (", ret.shape[i], ") specified");
  }

  return ret;
}
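
To make the defaulting behavior concrete, here is how it surfaces through the Python API (shapes illustrative; a sketch, not part of this diff):

    import torch

    x = torch.randn(4, 5, 6, dtype=torch.complex64)
    assert torch.fft.fftn(x).shape == (4, 5, 6)               # no s, no dim: transform all dims
    assert torch.fft.fftn(x, s=(8, 8)).shape == (4, 8, 8)     # dim defaults to the last len(s) dims
    assert torch.fft.fftn(x, dim=(0, -1)).shape == (4, 5, 6)  # sizes taken from the input
    assert torch.fft.fftn(x, s=(-1, 10), dim=(0, 2)).shape == (4, 5, 10)  # -1 keeps the input size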

// Complex to complex n-dimensional fft
Tensor fftn_c2c(
    const Tensor& input, IntArrayRef shape, IntArrayRef dim,
    c10::optional<std::string> norm_str, bool forward) {
  TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT");
  const auto input_dim = input.dim();

  Tensor x = resize_fft_input(input, dim, shape);
  x = at::view_as_real(x);

  const int64_t transform_ndim = dim.size();
  const auto norm = norm_from_string(norm_str, forward);
  // _fft_with_size only supports 3 dimensions being transformed at a time.
  // This limit is inherited from cuFFT.
  constexpr int64_t max_signal_ndim = 3;

  // Transform n dimensions, up to 3 at a time
  // TODO: rewrite _fft_with_size to transform more than 3 dimensions at once.
  for (int64_t i = 0; i < transform_ndim; i += max_signal_ndim) {
    const int64_t signal_ndim = std::min(transform_ndim - i, max_signal_ndim);
    DimVector source_dim(signal_ndim);
    DimVector dest_dim(signal_ndim);

    for (int64_t j = 0; j < signal_ndim; ++j) {
      source_dim[j] = dim[i + j];
      dest_dim[j] = j + (input_dim - signal_ndim);
    }

    // _fft operates on up to the last 3 dims, so move the selected dims to the end
    x = at::movedim(x, source_dim, dest_dim);

    x = _fft(x, signal_ndim, /*complex_input=*/true, /*complex_output=*/true,
             /*inverse=*/!forward, /*signal_sizes=*/{}, /*normalization=*/norm,
             /*onesided=*/false);

    // Move transform dims back to their original order
    x = at::movedim(x, dest_dim, source_dim);
  }

  return at::view_as_complex(x);
}
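
The chunking loop above can be mimicked at the Python level; a sketch that uses the public torch.fft.fftn in place of the internal _fft call (dims assumed non-negative and unique):

    import torch

    def fftn_chunked(x, dims, max_signal_ndim=3):
        # Transform up to 3 dims per pass, mirroring fftn_c2c above
        for i in range(0, len(dims), max_signal_ndim):
            chunk = list(dims[i:i + max_signal_ndim])
            dest = list(range(x.dim() - len(chunk), x.dim()))
            x = x.movedim(chunk, dest)        # move selected dims to the end
            x = torch.fft.fftn(x, dim=dest)   # transform the trailing dims
            x = x.movedim(dest, chunk)        # restore the original order
        return x

    x = torch.randn(2, 3, 4, 5, 6, dtype=torch.complex64)
    ref = torch.fft.fftn(x)
    out = fftn_chunked(x, dims=[0, 1, 2, 3, 4])
    assert (ref - out).abs().max() < 1e-3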

}

// torch.fft.fft, analogous to NumPy's numpy.fft.fft
@@ -240,6 +350,64 @@ Tensor fft_ihfft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
  return fft_r2c(self, n, dim, norm, /*forward=*/false, /*onesided=*/true);
}

Tensor fft_fftn(const Tensor& self, c10::optional<IntArrayRef> s,
                c10::optional<IntArrayRef> dim,
                c10::optional<std::string> norm) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  // TODO: For real input, perform rfftn then mirror with conjugate symmetry
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  return fftn_c2c(input, desc.shape, desc.dim, norm, /*forward=*/true);
}

Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s,
                 c10::optional<IntArrayRef> dim,
                 c10::optional<std::string> norm) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
  return fftn_c2c(input, desc.shape, desc.dim, norm, /*forward=*/false);
}

Tensor fft_rfftn(const Tensor& self, c10::optional<IntArrayRef> s,
                 c10::optional<IntArrayRef> dim,
                 c10::optional<std::string> norm) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis");

  const auto last_dim = desc.dim.back();
  const auto last_shape = desc.shape.back();
  desc.shape.pop_back();
  desc.dim.pop_back();

  // rfft on last dim to get hermitian complex shape
  auto x = native::fft_rfft(self, last_shape, last_dim, norm);
  // Normal fft on remaining dims
  return fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/true);
}
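
The decomposition used here -- a 1d real-to-complex rfft on the last dim, then ordinary complex-to-complex transforms over the remaining dims -- can be checked numerically (a sketch):

    import torch

    y = torch.randn(4, 5, 6)
    ref = torch.fft.rfftn(y)
    step = torch.fft.rfft(y, dim=-1)          # hermitian one-sided result in the last dim
    step = torch.fft.fftn(step, dim=(0, 1))   # ordinary c2c fft over the rest
    assert (ref - step).abs().max() < 1e-4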

Tensor fft_irfftn(const Tensor& self, c10::optional<IntArrayRef> s,
                  c10::optional<IntArrayRef> dim,
                  c10::optional<std::string> norm) {
  auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
  TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis");

  const auto last_dim = desc.dim.back();
  const auto last_shape = [&]() -> c10::optional<int64_t> {
    // If shape is defaulted in the last dimension,
    // pass nullopt to irfft and let it calculate the default size
    if (!s.has_value() || (s->back() == -1)) {
      return c10::nullopt;
    }
    return desc.shape.back();
  }();
  desc.shape.pop_back();
  desc.dim.pop_back();

  // Normal ifft for all but last dim
  Tensor x = promote_tensor_fft(self, /*require_complex=*/true);
  x = fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/false);
  // Then 1d irfft on last dim to get real output
  return native::fft_irfft(x, last_shape, last_dim, norm);
}
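
And the mirror image for irfftn: inverse c2c transforms over the leading dims, then a 1d c2r irfft on the last dim. Passing n (or s) matters because the one-sided layout cannot distinguish even from odd lengths (again a sketch):

    import torch

    X = torch.fft.rfftn(torch.randn(4, 5, 6))
    ref = torch.fft.irfftn(X, s=(4, 5, 6))
    step = torch.fft.ifftn(X, dim=(0, 1))       # inverse c2c on the leading dims
    step = torch.fft.irfft(step, n=6, dim=-1)   # 1d c2r on the last dim
    assert (ref - step).abs().max() < 1e-4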

// This is a pass-through wrapper function that does the size check and
// inferences. The actual forward implementation function is called
20 changes: 20 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -7943,6 +7943,26 @@
  use_c10_dispatcher: full
  variants: function

- func: fft_fftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
  python_module: fft
  use_c10_dispatcher: full
  variants: function

- func: fft_ifftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
  python_module: fft
  use_c10_dispatcher: full
  variants: function

- func: fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
  python_module: fft
  use_c10_dispatcher: full
  variants: function

- func: fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
  python_module: fft
  use_c10_dispatcher: full
  variants: function

- func: fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor
  use_c10_dispatcher: full
  variants: function, method
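
A note on the schema: int[1]? marks an optional argument that the Python arg parser accepts as either a single int or an int sequence, so (assuming the usual int[1] convention) these two calls should be equivalent:

    import torch

    x = torch.randn(4, 5, 6, dtype=torch.complex64)
    a = torch.fft.fftn(x, s=8, dim=0)        # scalar promoted to a one-element list
    b = torch.fft.fftn(x, s=[8], dim=[0])    # explicit lists
    assert torch.equal(a, b)
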
4 changes: 4 additions & 0 deletions docs/source/fft.rst
@@ -19,7 +19,11 @@ Functions

.. autofunction:: fft
.. autofunction:: ifft
.. autofunction:: fftn
.. autofunction:: ifftn
.. autofunction:: rfft
.. autofunction:: irfft
.. autofunction:: rfftn
.. autofunction:: irfftn
.. autofunction:: hfft
.. autofunction:: ihfft
