From fb8a672fec2810cb2bdcc58fdfafd6470548b529 Mon Sep 17 00:00:00 2001 From: Drew Hubley <96780897+spectre-ns@users.noreply.github.com> Date: Sun, 19 Nov 2023 09:40:59 -0400 Subject: [PATCH] Lfilter (#13) * added lfilter * Ran pre-commit * Fixed env file * Improved implementation * Ran pre-commit * Updated to reflect PR comments * Use size_type * Try to fix tests * Moved one pad out * Removed extra pad call * Run pre-commit --- .pre-commit-config.yaml | 6 +- CMakeLists.txt | 1 + environment-dev.yml | 1 + include/xtensor-signal/lfilter.hpp | 95 ++++++++++++++++++++++++++++++ test/CMakeLists.txt | 1 + test/lfilter_test.cpp | 60 +++++++++++++++++++ 6 files changed, 161 insertions(+), 3 deletions(-) create mode 100644 include/xtensor-signal/lfilter.hpp create mode 100644 test/lfilter_test.cpp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a0e9260..004718b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,13 +18,13 @@ repos: - id: detect-private-key - id: check-merge-conflict - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.4.2 + rev: v1.5.4 hooks: - id: forbid-tabs - id: remove-tabs args: [--whitespaces-count, '4'] - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.7.0 + rev: v2.10.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2'] @@ -41,7 +41,7 @@ repos: files: environment.yaml # Externally provided executables (so we can use them with editors as well). - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v15.0.7 + rev: v16.0.6 hooks: - id: clang-format files: .*\.[hc]pp$ diff --git a/CMakeLists.txt b/CMakeLists.txt index 800ab1a..84fb918 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,7 @@ endif() set(XTENSOR_SIGNAL_HEADERS ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/xtensor_signal.hpp ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/find_peaks.hpp + ${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/lfilter.hpp ) add_library(xtensor-signal INTERFACE) diff --git a/environment-dev.yml b/environment-dev.yml index 50c45e9..bdf9016 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -14,3 +14,4 @@ dependencies: - openblas - doctest - zlib +- pre-commit diff --git a/include/xtensor-signal/lfilter.hpp b/include/xtensor-signal/lfilter.hpp new file mode 100644 index 0000000..76f8043 --- /dev/null +++ b/include/xtensor-signal/lfilter.hpp @@ -0,0 +1,95 @@ +#ifndef XTENSOR_SIGNAL_LFILTER_HPP +#define XTENSOR_SIGNAL_LFILTER_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace xt { +namespace signal { +namespace detail { +template +inline auto lfilter(E1 &&b, E2 &&a, E3 &&x, E4 zi) { + using value_type = typename std::decay_t::value_type; + using size_type = typename std::decay_t::size_type; + if (zi.shape(0) != x.shape(0)) { + XTENSOR_THROW( + std::runtime_error, + "Accumulator initialization must be the same length as the input"); + } + if (x.dimension() != 1) { + XTENSOR_THROW(std::runtime_error, + "Implementation only works on 1D arguments"); + } + if (a.dimension() != 1) { + XTENSOR_THROW(std::runtime_error, + "Implementation only works on 1D arguments"); + } + if (b.dimension() != 1) { + XTENSOR_THROW(std::runtime_error, + "Implementation only works on 1D arguments"); + } + xt::xtensor out = + xt::zeros({x.shape(0) + 2 * (a.shape(0) - 1)}); + auto padded_x = xt::pad(x, b.shape(0) - 1); + for (size_type i = 0; i < x.shape(0); i++) { + auto b_accum = + xt::sum(b * + xt::flip(xt::view(padded_x, xt::range(i, i + b.shape(0))))) + + zi(i); + + auto a_accum = + b_accum - + xt::sum(xt::view(a, xt::range(1, xt::placeholders::_)) * + xt::flip(xt::view(out, xt::range(i, i + a.shape(0) - 1)))); + auto result = a_accum / a(0); + out(i + a.shape(0) - 1) = result(); + } + out = xt::view(out, xt::range(a.shape(0) - 1, -(a.shape(0) - 1))); + return out; +} +} // namespace detail + +/* + * @brief performs a 1D filter operation along the specified axis. Performs + * operations immediately. + * @param b the numerator of the filter expression + * @param a the denominator of the filter expression + * @param x input dataset + * @param axis the axis along which to perform the filter operation + * @param zi initial condition of the filter accumulator + * @return filtered version of x + * @todo Add implementation bound to MKL or HPC library for IIR and FIR + */ +template +inline auto lfilter(E1 &&b, E2 &&a, E3 &&x, std::ptrdiff_t axis = -1, + E4 zi = xt::xnone()) { + using value_type = typename std::decay_t::value_type; + xt::xarray out(x); + auto saxis = xt::normalize_axis(out.dimension(), axis); + auto begin = xt::axis_slice_begin(out, saxis); + auto end = xt::axis_slice_end(out, saxis); + + for (auto iter = begin; iter != end; iter++) { + if constexpr (std::is_same::type, + decltype(xt::xnone())>::value == false) { + xt::noalias(*iter) = detail::lfilter(b, a, *iter, zi); + } else { + xt::noalias(*iter) = detail::lfilter( + b, a, *iter, xt::zeros({(*iter).shape(0)})); + } + } + return out; +} +} // namespace signal +} // namespace xt + +#endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 280c4cf..4cff75a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -39,6 +39,7 @@ set(XTENSOR_SIGNAL_TESTS main.cpp test_config.cpp find_peaks_test.cpp + lfilter_test.cpp ) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") diff --git a/test/lfilter_test.cpp b/test/lfilter_test.cpp new file mode 100644 index 0000000..043f950 --- /dev/null +++ b/test/lfilter_test.cpp @@ -0,0 +1,60 @@ +#include "doctest/doctest.h" +#include "xtensor-signal/lfilter.hpp" +#include "xtensor/xio.hpp" +#include "xtensor/xrandom.hpp" +#include "xtensor/xsort.hpp" +#include "xtensor/xview.hpp" + +TEST_SUITE("lfilter") { + + TEST_CASE("3rdOrderButterworth") { + // credit + // https://rosettacode.org/wiki/Apply_a_digital_filter_(direct_form_II_transposed)#C++ + // define the signal + xt::xtensor sig = { + -0.917843918645, 0.141984778794, 1.20536903482, 0.190286794412, + -0.662370894973, -1.00700480494, -0.404707073677, 0.800482325044, + 0.743500089861, 1.01090520172, 0.741527555207, 0.277841675195, + 0.400833448236, -0.2085993586, -0.172842103641, -0.134316096293, + 0.0259303398477, 0.490105989562, 0.549391221511, 0.9047198589}; + + xt::xtensor expectation = { + -0.152974, -0.435258, -0.136043, 0.697503, 0.656445, + -0.435483, -1.08924, -0.537677, 0.51705, 1.05225, + 0.961854, 0.69569, 0.424356, 0.196262, -0.0278351, + -0.211722, -0.174746, 0.0692584, 0.385446, 0.651771}; + + // Constants for a Butterworth filter (order 3, low pass) + xt::xtensor a = {1.00000000, -2.77555756e-16, 3.33333333e-01, + -1.85037171e-17}; + xt::xtensor b = {0.16666667, 0.5, 0.5, 0.16666667}; + + auto res = xt::signal::lfilter(b, a, sig); + REQUIRE(xt::all(xt::isclose(res, expectation))); + } + TEST_CASE("3rdOrderButterworth_MultipleDims") { + xt::xtensor sig = { + -0.917843918645, 0.141984778794, 1.20536903482, 0.190286794412, + -0.662370894973, -1.00700480494, -0.404707073677, 0.800482325044, + 0.743500089861, 1.01090520172, 0.741527555207, 0.277841675195, + 0.400833448236, -0.2085993586, -0.172842103641, -0.134316096293, + 0.0259303398477, 0.490105989562, 0.549391221511, 0.9047198589}; + + xt::xtensor expectation = { + -0.152974, -0.435258, -0.136043, 0.697503, 0.656445, + -0.435483, -1.08924, -0.537677, 0.51705, 1.05225, + 0.961854, 0.69569, 0.424356, 0.196262, -0.0278351, + -0.211722, -0.174746, 0.0692584, 0.385446, 0.651771}; + + xt::xarray sig2 = + xt::stack(std::make_tuple(sig, xt::zeros_like(sig)), 1); + + // Constants for a Butterworth filter (order 3, low pass) + xt::xtensor a = {1.00000000, -2.77555756e-16, 3.33333333e-01, + -1.85037171e-17}; + xt::xtensor b = {0.16666667, 0.5, 0.5, 0.16666667}; + auto res = xt::signal::lfilter(b, a, sig2, 0); + REQUIRE(xt::all(xt::isclose(xt::view(res, xt::all(), 0), expectation))); + REQUIRE(xt::all(xt::isclose(xt::view(res, xt::all(), 1), 0))); + } +}