Skip to content

Commit

Permalink
Implemented xtensor FFT
Browse files Browse the repository at this point in the history
  • Loading branch information
spectre-ns committed Apr 24, 2024
1 parent 22ad9ea commit 932e387
Show file tree
Hide file tree
Showing 14 changed files with 615 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ __pycache__

# Generated files
*.pc
.vscode/settings.json
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ set(XTENSOR_HEADERS
${XTENSOR_INCLUDE_DIR}/xtensor/xfixed.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xfunction.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xfunctor_view.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xfft.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xgenerator.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xhistogram.hpp
${XTENSOR_INCLUDE_DIR}/xtensor/xindex_view.hpp
Expand Down Expand Up @@ -199,6 +200,7 @@ target_link_libraries(xtensor INTERFACE xtl)

OPTION(XTENSOR_ENABLE_ASSERT "xtensor bound check" OFF)
OPTION(XTENSOR_CHECK_DIMENSION "xtensor dimension check" OFF)
OPTION(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS "xtensor force the use of temporary memory when assigning instead of an automatic overlap check" ON)
OPTION(BUILD_TESTS "xtensor test suite" OFF)
OPTION(BUILD_BENCHMARK "xtensor benchmark" OFF)
OPTION(DOWNLOAD_GTEST "build gtest from downloaded sources" OFF)
Expand All @@ -219,6 +221,10 @@ if(XTENSOR_CHECK_DIMENSION)
add_definitions(-DXTENSOR_ENABLE_CHECK_DIMENSION)
endif()

# When the XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS option is enabled, expose it to the
# compiler as a preprocessor definition of the same name.
if(XTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
add_definitions(-DXTENSOR_FORCE_TEMPORARY_MEMORY_IN_ASSIGNMENTS)
endif()

if(DEFAULT_COLUMN_MAJOR)
add_definitions(-DXTENSOR_DEFAULT_LAYOUT=layout_type::column_major)
endif()
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# ![xtensor](docs/source/xtensor.svg)

![linux](https://github.com/xtensor-stack/xtensor/actions/workflows/linux.yml/badge.svg)
![osx](https://github.com/xtensor-stack/xtensor/actions/workflows/osx.yml/badge.svg)
![windows](https://github.com/xtensor-stack/xtensor/actions/workflows/windows.yml/badge.svg)
[![GHA Linux](https://github.com/xtensor-stack/xtensor/actions/workflows/linux.yml/badge.svg)](https://github.com/xtensor-stack/xtensor/actions/workflows/linux.yml)
[![GHA OSX](https://github.com/xtensor-stack/xtensor/actions/workflows/osx.yml/badge.svg)](https://github.com/xtensor-stack/xtensor/actions/workflows/osx.yml)
[![GHA Windows](https://github.com/xtensor-stack/xtensor/actions/workflows/windows.yml/badge.svg)](https://github.com/xtensor-stack/xtensor/actions/workflows/windows.yml)
[![Documentation](http://readthedocs.org/projects/xtensor/badge/?version=latest)](https://xtensor.readthedocs.io/en/latest/?badge=latest)
[![Doxygen -> gh-pages](https://github.com/xtensor-stack/xtensor/workflows/gh-pages/badge.svg)](https://xtensor-stack.github.io/xtensor)
[![Binder](https://mybinder.org/badge.svg)](https://mybinder.org/v2/gh/xtensor-stack/xtensor/stable?filepath=notebooks%2Fxtensor.ipynb)
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/container_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ xexpression API is actually implemented in ``xstrided_container`` and ``xcontain
xindex_view
xfunctor_view
xrepeat
xfft
17 changes: 17 additions & 0 deletions docs/source/xfft.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
Distributed under the terms of the BSD 3-Clause License.
The full license is in the file LICENSE, distributed with this software.
xfft
====

Defined in ``xtensor/xfft.hpp``

.. doxygenfunction:: xt::fft::convolve
:project: xtensor

.. doxygenfunction:: xt::fft::fft
:project: xtensor

.. doxygenfunction:: xt::fft::ifft
:project: xtensor
23 changes: 23 additions & 0 deletions include/xtensor/xbroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,29 @@ namespace xt
return linear_end(c.expression());
}

/*************************************
* overlapping_memory_checker_traits *
*************************************/

/**
 * Overlap check for broadcast expressions that expose no memory address of their own.
 *
 * The question is delegated to the checker of the wrapped expression. An expression
 * of size zero is reported as non-overlapping without inspecting the wrapped expression.
 */
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
{
    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        using wrapped_type = std::decay_t<decltype(expr.expression())>;
        // Short-circuit: the wrapped expression is only consulted for non-empty broadcasts.
        return expr.size() != 0
               && overlapping_memory_checker_traits<wrapped_type>::check_overlap(expr.expression(), dst_range);
    }
};

/**
* @class xbroadcast
* @brief Broadcasted xexpression to a specified shape.
Expand Down
241 changes: 241 additions & 0 deletions include/xtensor/xfft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
#ifdef XTENSOR_USE_TBB
#include <oneapi/tbb.h>
#endif
#include <stdexcept>

#include <xtl/xcomplex.hpp>

#include <xtensor/xarray.hpp>
#include <xtensor/xaxis_slice_iterator.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xcomplex.hpp>
#include <xtensor/xmath.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xview.hpp>

namespace xt
{
namespace fft
{
namespace detail
{
/**
 * @brief Recursive radix-2 FFT of a complex-valued expression.
 *
 * The input is copied into a 1D tensor, split into its even- and odd-indexed
 * samples, each half is transformed recursively and the two halves are combined
 * with the twiddle factors exp(-2i * pi * k / N).
 *
 * @param e complex-valued expression whose size must be a power of two
 * @return a 1D xtensor holding the transformed samples
 * @throws std::runtime_error if the size of @p e is zero or not a power of two
 */
template <
    class E,
    typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto radix2(E&& e)
{
    using namespace xt::placeholders;
    using namespace std::complex_literals;
    using value_type = typename std::decay_t<E>::value_type;
    using precision = typename value_type::value_type;

    auto N = e.size();
    // A power of two has a single bit set, so clearing its lowest set bit leaves zero.
    const bool single_bit_set = (N != 0) && ((N & (N - 1)) == 0);
    if (!single_bit_set)
    {
        // TODO: Replace implementation with dft
        XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
    }

    auto pi = xt::numeric_constants<precision>::PI;
    xt::xtensor<value_type, 1> ev = e;

    // Base case of the recursion. This return must stay ahead of the recursive
    // calls below so that the deduced return type is known when they are seen.
    if (N <= 1)
    {
        return ev;
    }

#ifdef XTENSOR_USE_TBB
    // Transform both halves as two parallel tasks.
    xt::xtensor<value_type, 1> even;
    xt::xtensor<value_type, 1> odd;
    oneapi::tbb::parallel_invoke(
        [&]
        {
            even = radix2(xt::view(ev, xt::range(0, _, 2)));
        },
        [&]
        {
            odd = radix2(xt::view(ev, xt::range(1, _, 2)));
        }
    );
#else
    auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
    auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
#endif

    // Twiddle factors for the butterfly step; these stay lazy expressions.
    auto k = xt::arange<double>(N / 2);
    auto twiddle = xt::exp(static_cast<value_type>(-2i) * pi * k / N);
    auto rotated_odd = twiddle * odd;

    // TODO: should be a call to stack if performance was improved
    auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
    xt::view(spectrum, xt::range(0, N / 2)) = even + rotated_odd;
    xt::view(spectrum, xt::range(N / 2, N)) = even - rotated_odd;
    return spectrum;
}

template <typename E>
auto transform_bluestein(E&& data)
{
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;

// Find a power-of-2 convolution length m such that m >= n * 2 + 1
const std::size_t n = data.size();
size_t m = std::ceil(std::log2(n * 2 + 1));
m = std::pow(2, m);

// Trignometric table
auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n});
xt::xtensor<std::size_t, 1> i = xt::pow(xt::linspace<std::size_t>(0, n - 1, n), 2);
i %= (n * 2);

auto angles = xt::eval(precision{3.141592653589793238463} * i / n);
auto j = std::complex<precision>(0, 1);
exp_table = xt::exp(-angles * j);

// Temporary vectors and preprocessing
auto av = xt::empty<std::complex<precision>>({m});
xt::view(av, xt::range(0, n)) = data * exp_table;


auto bv = xt::empty<std::complex<precision>>({m});
xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table);
xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view(
::xt::conj(xt::flip(exp_table)),
xt::range(xt::placeholders::_, -1)
);

// Convolution
auto xv = radix2(av);
auto yv = radix2(bv);
auto spectrum_k = xv * yv;
auto complex_args = xt::conj(spectrum_k);
auto fft_res = radix2(complex_args);
auto cv = xt::conj(fft_res) / m;

return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
}
} // namespace detail

/**
 * @brief 1D FFT of a complex-valued Nd expression along one axis.
 *
 * The expression is evaluated into an xarray and every 1D slice along the
 * chosen axis is replaced by its transform. Slices whose length is a power of
 * two go through detail::radix2, all other lengths through
 * detail::transform_bluestein.
 *
 * @param e a complex-valued Nd expression to be transformed to the fourier domain
 * @param axis the axis along which to perform the 1D FFT (negative values are normalized)
 * @return an xarray of std::complex with the precision of the input
 */
template <
    class E,
    typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
{
    using value_type = typename std::decay_t<E>::value_type;
    using precision = typename value_type::value_type;

    const auto saxis = xt::normalize_axis(e.dimension(), axis);
    const size_t N = e.shape(saxis);
    // The slice length is the same for every slice, so pick the algorithm once.
    const bool powerOfTwo = (N != 0) && ((N & (N - 1)) == 0);

    xt::xarray<std::complex<precision>> out = xt::eval(e);
    auto last = xt::axis_slice_end(out, saxis);
    for (auto slice = xt::axis_slice_begin(out, saxis); slice != last; ++slice)
    {
        if (!powerOfTwo)
        {
            xt::noalias(*slice) = detail::transform_bluestein(*slice);
        }
        else
        {
            xt::noalias(*slice) = detail::radix2(*slice);
        }
    }
    return out;
}

/**
 * @brief 1D FFT of a real-valued Nd expression along one axis.
 *
 * The input is cast element-wise to std::complex of its own value type and
 * forwarded to the complex overload.
 *
 * @param e a real-valued Nd expression to be transformed to the fourier domain
 * @param axis the axis along which to perform the 1D FFT
 * @return the result of the complex overload for the cast input
 */
template <
    class E,
    typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto fft(E&& e, std::ptrdiff_t axis = -1)
{
    using real_type = typename std::decay_t<E>::value_type;
    using complex_type = std::complex<real_type>;
    return fft(xt::cast<complex_type>(e), axis);
}

/**
 * @brief 1D inverse FFT of a complex-valued Nd expression along one axis.
 *
 * Computed as conj(fft(conj(e))). No 1/n scaling is applied by this function.
 *
 * @param e a complex-valued Nd expression to be transformed
 * @param axis the axis along which to perform the transform (negative values are normalized)
 * @return the unscaled inverse transform, as returned by the forward fft
 * @throws std::runtime_error if the length of @p e along @p axis is zero
 */
template <
    class E,
    typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
{
    // Normalize before indexing the shape: the default axis of -1 would otherwise
    // be converted to a huge unsigned index and read out of bounds.
    const auto saxis = xt::normalize_axis(e.dimension(), axis);

    // check the length of the data on that axis
    const std::size_t n = e.shape(saxis);
    if (n == 0)
    {
        XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
    }
    auto complex_args = xt::conj(e);
    auto fft_res = xt::fft::fft(complex_args, axis);
    fft_res = xt::conj(fft_res);
    return fft_res;
}

/**
 * @brief 1D inverse FFT of a real-valued Nd expression along one axis.
 *
 * The input is cast element-wise to std::complex of its own value type and
 * forwarded to the complex overload.
 *
 * @param e a real-valued Nd expression to be transformed
 * @param axis the axis along which to perform the transform
 * @return the result of the complex overload for the cast input
 */
template <
    class E,
    typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
{
    using real_type = typename std::decay_t<E>::value_type;
    using complex_type = std::complex<real_type>;
    return ifft(xt::cast<complex_type>(e), axis);
}

/**
 * @brief Performs a circular convolution of two arrays using the FFT.
 *
 * Both inputs are transformed along the axis, multiplied slice by slice,
 * transformed back and divided by the length along the axis. xvec and yvec
 * are expected to have the same shape; only the number of dimensions and the
 * length along the axis are validated here.
 *
 * @param xvec first array of the convolution
 * @param yvec second array of the convolution
 * @param axis axis along which to perform the convolution (negative values are normalized)
 * @return the convolved array
 * @throws std::runtime_error if the number of dimensions or the lengths along
 *         the axis differ
 */
template <typename E1, typename E2>
auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
{
    // we could broadcast but that could get complicated???
    if (xvec.dimension() != yvec.dimension())
    {
        XTENSOR_THROW(std::runtime_error, "Mismatched dimentions");
    }

    const auto saxis = xt::normalize_axis(xvec.dimension(), axis);
    if (xvec.shape(saxis) != yvec.shape(saxis))
    {
        XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
    }

    const std::size_t n = xvec.shape(saxis);

    // Forward the already-normalized axis so the transforms never see a negative
    // index; the complex ifft overload indexes shape() with the axis it receives.
    const auto norm_axis = static_cast<std::ptrdiff_t>(saxis);

    auto xv = fft(xvec, norm_axis);
    auto yv = fft(yvec, norm_axis);

    auto begin_x = xt::axis_slice_begin(xv, saxis);
    auto end_x = xt::axis_slice_end(xv, saxis);
    auto iter_y = xt::axis_slice_begin(yv, saxis);

    // Pointwise product of the spectra, one slice at a time, written into xv.
    for (auto iter = begin_x; iter != end_x; ++iter, ++iter_y)
    {
        (*iter) = (*iter_y) * (*iter);
    }

    auto outvec = ifft(xv, norm_axis);

    // Scaling (because this FFT implementation omits it)
    outvec = outvec / n;

    return outvec;
}

}
} // namespace xt::fft
36 changes: 36 additions & 0 deletions include/xtensor/xfunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,42 @@ namespace xt
{
};

/*************************************
* overlapping_memory_checker_traits *
*************************************/

/**
 * Overlap check for xfunction expressions that expose no memory address of their own.
 *
 * An xfunction overlaps the destination range as soon as one of its arguments
 * does; the arguments are visited in tuple order and the walk stops at the
 * first hit. An expression of size zero is reported as non-overlapping.
 */
template <class E>
struct overlapping_memory_checker_traits<
    E,
    std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
{
    // End of the recursion: every argument has been visited without finding an overlap.
    template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
    static bool check_tuple(const std::tuple<T...>&, const memory_range&)
    {
        return false;
    }

    // Checks argument I, then moves on to argument I + 1 if it does not overlap.
    template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
    static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
    {
        using argument_type = std::decay_t<decltype(std::get<I>(t))>;
        if (overlapping_memory_checker_traits<argument_type>::check_overlap(std::get<I>(t), dst_range))
        {
            return true;
        }
        return check_tuple<I + 1>(t, dst_range);
    }

    static bool check_overlap(const E& expr, const memory_range& dst_range)
    {
        // The arguments are only inspected for non-empty expressions.
        return expr.size() != 0 && check_tuple(expr.arguments(), dst_range);
    }
};

/*************
* xfunction *
*************/
Expand Down
Loading

0 comments on commit 932e387

Please sign in to comment.