From 2aa4f31c1f5e8eaf0c24dd4cfbfb57fd984e6040 Mon Sep 17 00:00:00 2001 From: Drew Hubley <96780897+spectre-ns@users.noreply.github.com> Date: Fri, 24 Nov 2023 08:53:20 -0400 Subject: [PATCH] FFT Radix 2 implementation (#15) * added lfilter * Ran pre-commit * Fixed env file * Improved implementation * Ran pre-commit * Updated to reflect PR comments * Use size_type * Try to fix tests * Moved one pad out * Removed extra pad call * Run pre-commit * Added basic fft more work needed * pre-commit * Tests passing * Updated cmake version * Run pre=commit * Added axis * pre-commit * Added check for power of 2 * Added todo * pre-commit * Added comment * Added power of 2 check * Added TBB to configuration * Ran pre-commit * Added comments * Format again * Update ghworkflow.yml * Update fft.hpp Added comments --- .github/workflows/ghworkflow.yml | 6 +- CMakeLists.txt | 11 +++- include/xtensor-signal/fft.hpp | 99 ++++++++++++++++++++++++++++++++ test/CMakeLists.txt | 11 +++- test/fft_test.cpp | 49 ++++++++++++++++ 5 files changed, 173 insertions(+), 3 deletions(-) create mode 100644 include/xtensor-signal/fft.hpp create mode 100644 test/fft_test.cpp diff --git a/.github/workflows/ghworkflow.yml b/.github/workflows/ghworkflow.yml index 378aee4..e258cfe 100644 --- a/.github/workflows/ghworkflow.yml +++ b/.github/workflows/ghworkflow.yml @@ -1,6 +1,10 @@ name: CI -on: push +on: + workflow_dispatch: + pull_request: + push: + branches: [master] jobs: diff --git a/CMakeLists.txt b/CMakeLists.txt index 84fb918..650966e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.5) project(xtensor-signal) set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) @@ -67,6 +67,8 @@ set(XTENSOR_SIGNAL_HEADERS ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/xtensor_signal.hpp ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/find_peaks.hpp ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/lfilter.hpp + ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/fft.hpp + ) add_library(xtensor-signal INTERFACE) @@ -78,11 +80,18 @@ target_include_directories(xtensor-signal INTERFACE target_link_libraries(xtensor-signal INTERFACE xtensor xsimd) OPTION(BUILD_TESTS "xtensor test suite" OFF) +OPTION(XTENSOR_USE_TBB "Use tbb libraries" OFF) if(BUILD_TESTS) add_subdirectory(test) endif() +if(XTENSOR_USE_TBB) + find_package(TBB REQUIRED) + message(STATUS "Found intel TBB: ${TBB_INCLUDE_DIRS}") +endif() + + # Installation # ============ diff --git a/include/xtensor-signal/fft.hpp b/include/xtensor-signal/fft.hpp new file mode 100644 index 0000000..2b5cd84 --- /dev/null +++ b/include/xtensor-signal/fft.hpp @@ -0,0 +1,99 @@ + +#ifdef XTENSOR_USE_TBB +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +namespace xt::fft { +namespace detail { +template ::type::value_type>::value, + bool>::type = true> +inline auto fft(E &&e) { + using namespace xt::placeholders; + using namespace std::complex_literals; + using value_type = typename std::decay_t::value_type; + using precision = typename value_type::value_type; + auto N = e.size(); + const bool powerOfTwo = !(N == 0) && !(N & (N - 1)); + // check for power of 2 + if (!powerOfTwo || N == 0) { + // TODO: Replace implementation with dft + XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2"); + } + auto pi = xt::numeric_constants::PI; + xt::xtensor ev = e; + if (N <= 1) { + return ev; + } else { +#ifdef XTENSOR_USE_TBB + xt::xtensor even; + xt::xtensor odd; + oneapi::tbb::parallel_invoke( + [&] { even = fft(xt::view(ev, xt::range(0, _, 2))); }, + [&] { odd = fft(xt::view(ev, xt::range(1, _, 2))); }); +#else + auto even = fft(xt::view(ev, xt::range(0, _, 2))); + auto odd = fft(xt::view(ev, xt::range(1, _, 2))); +#endif + + auto range = xt::arange(N / 2); + auto exp = xt::exp(static_cast(-2i) * pi * range / N); + auto t = exp * odd; + auto first_half = even + t; + auto second_half = even - t; + // TODO: should be a call to stack if performance was improved + auto spectrum = xt::xtensor::from_shape({N}); + xt::view(spectrum, xt::range(0, N / 2)) = first_half; + xt::view(spectrum, xt::range(N / 2, N)) = second_half; + return spectrum; + } +} +} // namespace detail + +/** +* @breif 1D FFT of an Nd array along a specified axis +* @param e an Nd expression to be transformed to the fourier domain +* @param axis the axis along which to perform the 1D FFT +* @return a transformed xarray of the specified precision +*/ +template ::type::value_type>::value, + bool>::type = true> +inline auto fft(E &&e, std::ptrdiff_t axis = -1) { + using value_type = typename std::decay_t::value_type; + using precision = typename value_type::value_type; + xt::xarray> out = xt::eval(e); + auto saxis = xt::normalize_axis(e.dimension(), axis); + auto begin = xt::axis_slice_begin(out, saxis); + auto end = xt::axis_slice_end(out, saxis); + for (auto iter = begin; iter != end; iter++) { + xt::noalias(*iter) = detail::fft(*iter); + } + return out; +} + +/** +* @breif 1D FFT of an Nd array along a specified axis +* @param e an Nd expression to be transformed to the fourier domain +* @param axis the axis along which to perform the 1D FFT +* @return a transformed xarray of the specified precision +*/ +template ::type::value_type>::value, + bool>::type = true> +inline auto fft(E &&e, std::ptrdiff_t axis = -1) { + using value_type = typename std::decay::type::value_type; + return fft(xt::cast>(e), axis); +} + +} // namespace xt::fft diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4cff75a..06e767b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,7 +7,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.5) if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) project(xtensor-signal-test) @@ -40,15 +40,24 @@ set(XTENSOR_SIGNAL_TESTS test_config.cpp find_peaks_test.cpp lfilter_test.cpp + fft_test.cpp ) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${_cxx_std_flag} /MP /bigobj") endif() + + file(COPY "test_data" DESTINATION "${CMAKE_BINARY_DIR}/test") add_executable(test_xtensor_signal ${XTENSOR_SIGNAL_TESTS} ${XTENSOR_SIGNAL_HEADERS}) +if(XTENSOR_USE_TBB) + target_compile_definitions(test_xtensor_signal PRIVATE XTENSOR_USE_TBB) + target_include_directories(test_xtensor_signal PRIVATE ${TBB_INCLUDE_DIRS}) + target_link_libraries(test_xtensor_signal PRIVATE ${TBB_LIBRARIES}) +endif() + target_link_libraries(test_xtensor_signal PRIVATE ZLIB::ZLIB xtensor-signal doctest::doctest ${CMAKE_THREAD_LIBS_INIT}) add_custom_target(xtest COMMAND ./test_xtensor_signal DEPENDS test_xtensor_signal) diff --git a/test/fft_test.cpp b/test/fft_test.cpp new file mode 100644 index 0000000..7adc875 --- /dev/null +++ b/test/fft_test.cpp @@ -0,0 +1,49 @@ + +#include "doctest/doctest.h" +#include "xtensor-signal/fft.hpp" +#include + +TEST_SUITE("fft") { + + TEST_CASE("fft_single") { + bool powerOfTwo = !(8 == 0) && !(8 & (8 - 1)); + xt::xtensor input = {1, 1, 1, 1, 0, 0, 0, 0}; + xt::xtensor expectation = {4.000, 2.613, 0.000, 1.082, + 0.000, 1.082, 0.000, 2.613}; + auto result = xt::fft::fft(input); + REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); + } + TEST_CASE("fft_double") { + xt::xtensor input = {1, 1, 1, 1, 0, 0, 0, 0}; + xt::xtensor expectation = {4.000, 2.613, 0.000, 1.082, + 0.000, 1.082, 0.000, 2.613}; + auto result = xt::fft::fft(input); + REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); + } + TEST_CASE("fft_csingle") { + xt::xtensor, 1> input = {1, 1, 1, 1, 0, 0, 0, 0}; + xt::xtensor expectation = {4.000, 2.613, 0.000, 1.082, + 0.000, 1.082, 0.000, 2.613}; + auto result = xt::fft::fft(input); + REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); + } + TEST_CASE("fft_cdouble") { + xt::xtensor, 1> input = {1, 1, 1, 1, 0, 0, 0, 0}; + xt::xtensor expectation = {4.000, 2.613, 0.000, 1.082, + 0.000, 1.082, 0.000, 2.613}; + auto result = xt::fft::fft(input); + REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001))); + } + + TEST_CASE("fft_double_axis0") { + xt::xarray input = {{1, 1}, {1, 1}, {1, 1}, {1, 1}, + {0, 0}, {0, 0}, {0, 0}, {0, 0}}; + xt::xarray expectation = {4.000, 2.613, 0.000, 1.082, + 0.000, 1.082, 0.000, 2.613}; + auto result = xt::fft::fft(input, 0); + auto first_column = xt::view(result, xt::all(), 0); + REQUIRE(xt::all(xt::isclose(xt::abs(first_column), expectation, .001))); + auto second_column = xt::view(result, xt::all(), 1); + REQUIRE(xt::all(xt::isclose(xt::abs(second_column), expectation, .001))); + } +}