-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added lfilter * Ran pre-commit * Fixed env file * Improved implementation * Ran pre-commit * Updated to reflect PR comments * Use size_type * Try to fix tests * Moved one pad out * Removed extra pad call * Run pre-commit * Added basic fft more work needed * pre-commit * Tests passing * Updated cmake version * Run pre=commit * Added axis * pre-commit * Added check for power of 2 * Added todo * pre-commit * Added comment * Added power of 2 check * Added TBB to configuration * Ran pre-commit * Added comments * Format again * Update ghworkflow.yml * Update fft.hpp Added comments
- Loading branch information
1 parent
fb8a672
commit 2aa4f31
Showing
5 changed files
with
173 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
name: CI | ||
|
||
on: push | ||
on: | ||
workflow_dispatch: | ||
pull_request: | ||
push: | ||
branches: [master] | ||
|
||
jobs: | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
|
||
#ifdef XTENSOR_USE_TBB | ||
#include <oneapi/tbb.h> | ||
#endif | ||
#include <stdexcept> | ||
#include <xtensor/xarray.hpp> | ||
#include <xtensor/xaxis_slice_iterator.hpp> | ||
#include <xtensor/xbuilder.hpp> | ||
#include <xtensor/xnoalias.hpp> | ||
#include <xtensor/xview.hpp> | ||
#include <xtl/xcomplex.hpp> | ||
|
||
namespace xt::fft { | ||
namespace detail { | ||
template <class E, | ||
typename std::enable_if< | ||
xtl::is_complex<typename std::decay<E>::type::value_type>::value, | ||
bool>::type = true> | ||
inline auto fft(E &&e) { | ||
using namespace xt::placeholders; | ||
using namespace std::complex_literals; | ||
using value_type = typename std::decay_t<E>::value_type; | ||
using precision = typename value_type::value_type; | ||
auto N = e.size(); | ||
const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); | ||
// check for power of 2 | ||
if (!powerOfTwo || N == 0) { | ||
// TODO: Replace implementation with dft | ||
XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2"); | ||
} | ||
auto pi = xt::numeric_constants<precision>::PI; | ||
xt::xtensor<value_type, 1> ev = e; | ||
if (N <= 1) { | ||
return ev; | ||
} else { | ||
#ifdef XTENSOR_USE_TBB | ||
xt::xtensor<value_type, 1> even; | ||
xt::xtensor<value_type, 1> odd; | ||
oneapi::tbb::parallel_invoke( | ||
[&] { even = fft(xt::view(ev, xt::range(0, _, 2))); }, | ||
[&] { odd = fft(xt::view(ev, xt::range(1, _, 2))); }); | ||
#else | ||
auto even = fft(xt::view(ev, xt::range(0, _, 2))); | ||
auto odd = fft(xt::view(ev, xt::range(1, _, 2))); | ||
#endif | ||
|
||
auto range = xt::arange<double>(N / 2); | ||
auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N); | ||
auto t = exp * odd; | ||
auto first_half = even + t; | ||
auto second_half = even - t; | ||
// TODO: should be a call to stack if performance was improved | ||
auto spectrum = xt::xtensor<value_type, 1>::from_shape({N}); | ||
xt::view(spectrum, xt::range(0, N / 2)) = first_half; | ||
xt::view(spectrum, xt::range(N / 2, N)) = second_half; | ||
return spectrum; | ||
} | ||
} | ||
} // namespace detail | ||
|
||
/** | ||
* @breif 1D FFT of an Nd array along a specified axis | ||
* @param e an Nd expression to be transformed to the fourier domain | ||
* @param axis the axis along which to perform the 1D FFT | ||
* @return a transformed xarray of the specified precision | ||
*/ | ||
template <class E, | ||
typename std::enable_if< | ||
xtl::is_complex<typename std::decay<E>::type::value_type>::value, | ||
bool>::type = true> | ||
inline auto fft(E &&e, std::ptrdiff_t axis = -1) { | ||
using value_type = typename std::decay_t<E>::value_type; | ||
using precision = typename value_type::value_type; | ||
xt::xarray<std::complex<precision>> out = xt::eval(e); | ||
auto saxis = xt::normalize_axis(e.dimension(), axis); | ||
auto begin = xt::axis_slice_begin(out, saxis); | ||
auto end = xt::axis_slice_end(out, saxis); | ||
for (auto iter = begin; iter != end; iter++) { | ||
xt::noalias(*iter) = detail::fft(*iter); | ||
} | ||
return out; | ||
} | ||
|
||
/** | ||
* @breif 1D FFT of an Nd array along a specified axis | ||
* @param e an Nd expression to be transformed to the fourier domain | ||
* @param axis the axis along which to perform the 1D FFT | ||
* @return a transformed xarray of the specified precision | ||
*/ | ||
template <class E, | ||
typename std::enable_if< | ||
!xtl::is_complex<typename std::decay<E>::type::value_type>::value, | ||
bool>::type = true> | ||
inline auto fft(E &&e, std::ptrdiff_t axis = -1) { | ||
using value_type = typename std::decay<E>::type::value_type; | ||
return fft(xt::cast<std::complex<value_type>>(e), axis); | ||
} | ||
|
||
} // namespace xt::fft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
|
||
#include "doctest/doctest.h" | ||
#include "xtensor-signal/fft.hpp" | ||
#include <xtensor/xio.hpp> | ||
|
||
TEST_SUITE("fft") { | ||
|
||
TEST_CASE("fft_single") { | ||
bool powerOfTwo = !(8 == 0) && !(8 & (8 - 1)); | ||
xt::xtensor<float, 1> input = {1, 1, 1, 1, 0, 0, 0, 0}; | ||
xt::xtensor<float, 1> expectation = {4.000, 2.613, 0.000, 1.082, | ||
0.000, 1.082, 0.000, 2.613}; | ||
auto result = xt::fft::fft(input); | ||
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); | ||
} | ||
TEST_CASE("fft_double") { | ||
xt::xtensor<double, 1> input = {1, 1, 1, 1, 0, 0, 0, 0}; | ||
xt::xtensor<double, 1> expectation = {4.000, 2.613, 0.000, 1.082, | ||
0.000, 1.082, 0.000, 2.613}; | ||
auto result = xt::fft::fft(input); | ||
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); | ||
} | ||
TEST_CASE("fft_csingle") { | ||
xt::xtensor<std::complex<float>, 1> input = {1, 1, 1, 1, 0, 0, 0, 0}; | ||
xt::xtensor<float, 1> expectation = {4.000, 2.613, 0.000, 1.082, | ||
0.000, 1.082, 0.000, 2.613}; | ||
auto result = xt::fft::fft(input); | ||
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); | ||
} | ||
TEST_CASE("fft_cdouble") { | ||
xt::xtensor<std::complex<double>, 1> input = {1, 1, 1, 1, 0, 0, 0, 0}; | ||
xt::xtensor<double, 1> expectation = {4.000, 2.613, 0.000, 1.082, | ||
0.000, 1.082, 0.000, 2.613}; | ||
auto result = xt::fft::fft(input); | ||
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); | ||
} | ||
|
||
TEST_CASE("fft_double_axis0") { | ||
xt::xarray<double> input = {{1, 1}, {1, 1}, {1, 1}, {1, 1}, | ||
{0, 0}, {0, 0}, {0, 0}, {0, 0}}; | ||
xt::xarray<double> expectation = {4.000, 2.613, 0.000, 1.082, | ||
0.000, 1.082, 0.000, 2.613}; | ||
auto result = xt::fft::fft(input, 0); | ||
auto first_column = xt::view(result, xt::all(), 0); | ||
REQUIRE(xt::all(xt::isclose(xt::abs(first_column), expectation, .001))); | ||
auto second_column = xt::view(result, xt::all(), 1); | ||
REQUIRE(xt::all(xt::isclose(xt::abs(second_column), expectation, .001))); | ||
} | ||
} |