Skip to content

Commit

Permalink
FFT Radix 2 implementation (#15)
Browse files Browse the repository at this point in the history
* added lfilter

* Ran pre-commit

* Fixed env file

* Improved implementation

* Ran pre-commit

* Updated to reflect PR comments

* Use size_type

* Try to fix tests

* Moved one pad out

* Removed extra pad call

* Run pre-commit

* Added basic fft more work needed

* pre-commit

* Tests passing

* Updated cmake version

* Run pre=commit

* Added axis

* pre-commit

* Added check for power of 2

* Added todo

* pre-commit

* Added comment

* Added power of 2 check

* Added TBB to configuration

* Ran pre-commit

* Added comments

* Format again

* Update ghworkflow.yml

* Update fft.hpp

Added comments
  • Loading branch information
spectre-ns authored Nov 24, 2023
1 parent fb8a672 commit 2aa4f31
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 3 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ghworkflow.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name: CI

on: push
on:
workflow_dispatch:
pull_request:
push:
branches: [master]

jobs:

Expand Down
11 changes: 10 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)
project(xtensor-signal)

set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
Expand Down Expand Up @@ -67,6 +67,8 @@ set(XTENSOR_SIGNAL_HEADERS
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/xtensor_signal.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/find_peaks.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/lfilter.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/fft.hpp

)

add_library(xtensor-signal INTERFACE)
Expand All @@ -78,11 +80,18 @@ target_include_directories(xtensor-signal INTERFACE
target_link_libraries(xtensor-signal INTERFACE xtensor xsimd)

OPTION(BUILD_TESTS "xtensor test suite" OFF)
OPTION(XTENSOR_USE_TBB "Use tbb libraries" OFF)

if(BUILD_TESTS)
add_subdirectory(test)
endif()

if(XTENSOR_USE_TBB)
find_package(TBB REQUIRED)
message(STATUS "Found intel TBB: ${TBB_INCLUDE_DIRS}")
endif()


# Installation
# ============

Expand Down
99 changes: 99 additions & 0 deletions include/xtensor-signal/fft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

#ifdef XTENSOR_USE_TBB
#include <oneapi/tbb.h>
#endif
#include <stdexcept>
#include <xtensor/xarray.hpp>
#include <xtensor/xaxis_slice_iterator.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xview.hpp>
#include <xtl/xcomplex.hpp>

namespace xt::fft {
namespace detail {
template <class E,
typename std::enable_if<
xtl::is_complex<typename std::decay<E>::type::value_type>::value,
bool>::type = true>
inline auto fft(E &&e) {
using namespace xt::placeholders;
using namespace std::complex_literals;
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;
auto N = e.size();
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
// check for power of 2
if (!powerOfTwo || N == 0) {
// TODO: Replace implementation with dft
XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
}
auto pi = xt::numeric_constants<precision>::PI;
xt::xtensor<value_type, 1> ev = e;
if (N <= 1) {
return ev;
} else {
#ifdef XTENSOR_USE_TBB
xt::xtensor<value_type, 1> even;
xt::xtensor<value_type, 1> odd;
oneapi::tbb::parallel_invoke(
[&] { even = fft(xt::view(ev, xt::range(0, _, 2))); },
[&] { odd = fft(xt::view(ev, xt::range(1, _, 2))); });
#else
auto even = fft(xt::view(ev, xt::range(0, _, 2)));
auto odd = fft(xt::view(ev, xt::range(1, _, 2)));
#endif

auto range = xt::arange<double>(N / 2);
auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
auto t = exp * odd;
auto first_half = even + t;
auto second_half = even - t;
// TODO: should be a call to stack if performance was improved
auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
xt::view(spectrum, xt::range(0, N / 2)) = first_half;
xt::view(spectrum, xt::range(N / 2, N)) = second_half;
return spectrum;
}
}
} // namespace detail

/**
* @breif 1D FFT of an Nd array along a specified axis
* @param e an Nd expression to be transformed to the fourier domain
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <class E,
typename std::enable_if<
xtl::is_complex<typename std::decay<E>::type::value_type>::value,
bool>::type = true>
inline auto fft(E &&e, std::ptrdiff_t axis = -1) {
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;
xt::xarray<std::complex<precision>> out = xt::eval(e);
auto saxis = xt::normalize_axis(e.dimension(), axis);
auto begin = xt::axis_slice_begin(out, saxis);
auto end = xt::axis_slice_end(out, saxis);
for (auto iter = begin; iter != end; iter++) {
xt::noalias(*iter) = detail::fft(*iter);
}
return out;
}

/**
* @breif 1D FFT of an Nd array along a specified axis
* @param e an Nd expression to be transformed to the fourier domain
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <class E,
typename std::enable_if<
!xtl::is_complex<typename std::decay<E>::type::value_type>::value,
bool>::type = true>
inline auto fft(E &&e, std::ptrdiff_t axis = -1) {
using value_type = typename std::decay<E>::type::value_type;
return fft(xt::cast<std::complex<value_type>>(e), axis);
}

} // namespace xt::fft
11 changes: 10 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)

if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
project(xtensor-signal-test)
Expand Down Expand Up @@ -40,15 +40,24 @@ set(XTENSOR_SIGNAL_TESTS
test_config.cpp
find_peaks_test.cpp
lfilter_test.cpp
fft_test.cpp
)

if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${_cxx_std_flag} /MP /bigobj")
endif()



file(COPY "test_data" DESTINATION "${CMAKE_BINARY_DIR}/test")

add_executable(test_xtensor_signal ${XTENSOR_SIGNAL_TESTS} ${XTENSOR_SIGNAL_HEADERS})
if(XTENSOR_USE_TBB)
target_compile_definitions(test_xtensor_signal PRIVATE XTENSOR_USE_TBB)
target_include_directories(test_xtensor_signal PRIVATE ${TBB_INCLUDE_DIRS})
target_link_libraries(test_xtensor_signal PRIVATE ${TBB_LIBRARIES})
endif()

target_link_libraries(test_xtensor_signal PRIVATE ZLIB::ZLIB xtensor-signal doctest::doctest ${CMAKE_THREAD_LIBS_INIT})

add_custom_target(xtest COMMAND ./test_xtensor_signal DEPENDS test_xtensor_signal)
49 changes: 49 additions & 0 deletions test/fft_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

#include "doctest/doctest.h"
#include "xtensor-signal/fft.hpp"
#include <xtensor/xio.hpp>

TEST_SUITE("fft") {

TEST_CASE("fft_single") {
bool powerOfTwo = !(8 == 0) && !(8 & (8 - 1));
xt::xtensor<float, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<float, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}
TEST_CASE("fft_double") {
xt::xtensor<double, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<double, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}
TEST_CASE("fft_csingle") {
xt::xtensor<std::complex<float>, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<float, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}
TEST_CASE("fft_cdouble") {
xt::xtensor<std::complex<double>, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<double, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}

TEST_CASE("fft_double_axis0") {
xt::xarray<double> input = {{1, 1}, {1, 1}, {1, 1}, {1, 1},
{0, 0}, {0, 0}, {0, 0}, {0, 0}};
xt::xarray<double> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input, 0);
auto first_column = xt::view(result, xt::all(), 0);
REQUIRE(xt::all(xt::isclose(xt::abs(first_column), expectation, .001)));
auto second_column = xt::view(result, xt::all(), 1);
REQUIRE(xt::all(xt::isclose(xt::abs(second_column), expectation, .001)));
}
}

0 comments on commit 2aa4f31

Please sign in to comment.