diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index 50fe7d1344eaf..9806bc179b445 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -41,11 +41,11 @@ ONNX_CPU_OPERATOR_KERNEL(STFT, 17, STFT); static bool is_real_valued_signal(const onnxruntime::TensorShape& shape) { - return shape.NumDimensions() == 2 || shape[shape.NumDimensions() - 1] == 1; + return shape.NumDimensions() >= 2 && shape[shape.NumDimensions() - 1] == 1; } static bool is_complex_valued_signal(const onnxruntime::TensorShape& shape) { - return shape.NumDimensions() > 2 && shape[shape.NumDimensions() - 1] == 2; + return shape.NumDimensions() >= 2 && shape[shape.NumDimensions() - 1] == 2; } constexpr static bool is_power_of_2(size_t size) { @@ -143,9 +143,28 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); - auto x = (bit_reversed_index < number_of_samples) ? *(X_data + bit_reversed_index * X_stride) : 0; - auto window_element = window_data ? *(window_data + bit_reversed_index) : 1; - *(Y_data + i * Y_data_stride) = std::complex(1, 0) * x * window_element; + if (bit_reversed_index < number_of_samples) { + auto x = *(X_data + bit_reversed_index * X_stride); + auto window_element = window_data ? *(window_data + bit_reversed_index) : 1; + *(Y_data + i * Y_data_stride) = std::complex(1, 0) * x * window_element; + } + } + + // For IRFFT, fix up the negative frequencies using conjugate symmetry + if (is_onesided && inverse) { + size_t conjugate_end = (dft_length % 2 == 0) ? (number_of_samples - 1) : number_of_samples; + for (size_t k = 1; k < conjugate_end; k++) { + // Find position in bit-reversed array that represents natural index N-k + size_t pos_nk = bit_reverse(dft_length - k, significant_bits); + + // Get the input value and window at position k + auto x_k = *(X_data + k * X_stride); + auto window_k = window_data ? *(window_data + k) : 1; + + // Apply conjugate symmetry to input, then apply window + // X[N-k] = conj(X[k]), then multiply by window + *(Y_data + pos_nk * Y_data_stride) = std::conj(x_k) * window_k; + } } // Run fft_radix2 @@ -179,10 +198,17 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s } if (is_onesided) { - const size_t output_size = (dft_length >> 1) + 1; - auto destination = reinterpret_cast*>(Y->MutableDataRaw()) + Y_offset; - for (size_t i = 0; i < output_size; i++) { - *(destination + Y_stride * i) = *(Y_data + i * Y_data_stride); + if (inverse) { + auto destination = reinterpret_cast(Y->MutableDataRaw()) + Y_offset; + for (size_t i = 0; i < dft_length; i++) { + *(destination + Y_stride * i) = (*(Y_data + i * Y_data_stride)).real(); + } + } else { + const size_t output_size = (dft_length >> 1) + 1; + auto destination = reinterpret_cast*>(Y->MutableDataRaw()) + Y_offset; + for (size_t i = 0; i < output_size; i++) { + *(destination + Y_stride * i) = *(Y_data + i * Y_data_stride); + } } } @@ -202,7 +228,7 @@ T next_power_of_2(T in) { template static Status dft_bluestein_z_chirp( OpKernelContext* ctx, const Tensor* X, Tensor* Y, Tensor& b_fft, Tensor& chirp, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride, - int64_t axis, size_t dft_length, const Tensor* window, bool inverse, InlinedVector>& V, + int64_t axis, size_t dft_length, const Tensor* window, bool is_onesided, bool inverse, InlinedVector>& V, InlinedVector>& temp_output) { static constexpr T pi = static_cast(M_PI); @@ -255,7 +281,6 @@ static Status dft_bluestein_z_chirp( // Get data auto* X_data = const_cast(reinterpret_cast(X->DataRaw())) + X_offset; - auto* Y_data = reinterpret_cast*>(Y->MutableDataRaw()) + Y_offset; U* window_data = nullptr; if (window) { window_data = const_cast(reinterpret_cast(window->DataRaw())); @@ -281,6 +306,20 @@ static Status dft_bluestein_z_chirp( a_n *= window_n; a_n *= chirp_n; } + if (inverse && is_onesided) { + // IRFFT: fill in negative frequencies using conjugate symmetry + // For the input X: X[N-k] = conj(X[k]) for k=1..floor((N-1)/2) + // We need to apply this BEFORE the chirp multiplication + // So: a[N-k] = conj(X[k]) * window[N-k] * chirp[N-k] + // = conj(X[k]) * window[k] * chirp[N-k] (window is symmetric) + size_t conjugate_end = (N % 2 == 0) ? (number_of_samples - 1) : number_of_samples; + for (size_t k = 1; k < conjugate_end; k++) { + auto x_k = *(X_data + k * X_stride); // Original input at k + auto window_k = window_data ? *(window_data + k) : 1; + std::complex& chirp_nk = *(chirp_data + N - k); + *(a_data + N - k) = std::conj(x_k) * window_k * chirp_nk; + } + } // Forward FFT radix2 for the "a" signal ORT_RETURN_IF_ERROR((fft_radix2>(ctx, &a, &a_fft, 0, 1, 0, 1, 1, M, nullptr, @@ -298,17 +337,33 @@ static Status dft_bluestein_z_chirp( const auto& Y_shape = Y->Shape(); size_t dft_output_size = static_cast(Y_shape[onnxruntime::narrow(axis)]); - for (size_t i = 0; i < dft_output_size; i++) { - std::complex& chirp_i = *(chirp_data + i); - std::complex& out = *(Y_data + i * Y_stride); - std::complex& c_i = *(a_data + i); - if (i > 0) { - // The inverse fft is computed using the same cached vandermonde matrix (V) created by the - // forward fft. This reversal causes the output to be reversed as well. - // Therefore we undo the reversal when writing the output back out. - c_i = *(a_data + M - i); + if (inverse && is_onesided) { + // IRFFT: extract real part only + auto* Y_data = reinterpret_cast(Y->MutableDataRaw()) + Y_offset; + for (size_t i = 0; i < dft_output_size; i++) { + std::complex& chirp_i = *(chirp_data + i); + T& out = *(Y_data + i * Y_stride); + std::complex& c_i = *(a_data + i); + if (i > 0) { + c_i = *(a_data + M - i); + } + out = (c_i * chirp_i * scale).real(); + } + } else { + // Standard complex output + auto* Y_data = reinterpret_cast*>(Y->MutableDataRaw()) + Y_offset; + for (size_t i = 0; i < dft_output_size; i++) { + std::complex& chirp_i = *(chirp_data + i); + std::complex& out = *(Y_data + i * Y_stride); + std::complex& c_i = *(a_data + i); + if (i > 0) { + // The inverse fft is computed using the same cached vandermonde matrix (V) created by the + // forward fft. This reversal causes the output to be reversed as well. + // Therefore we undo the reversal when writing the output back out. + c_i = *(a_data + M - i); + } + out = c_i * chirp_i * scale; } - out = c_i * chirp_i * scale; } return Status::OK(); } @@ -325,7 +380,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, auto batch_and_signal_rank = X->Shape().NumDimensions(); auto total_dfts = static_cast(X->Shape().Size() / X->Shape()[onnxruntime::narrow(axis)]); - auto is_input_real = X->Shape().NumDimensions() == 2 || X->Shape()[X->Shape().NumDimensions() - 1] == 1; + auto is_input_real = X->Shape()[X->Shape().NumDimensions() - 1] == 1; auto complex_input_factor = is_input_real ? 1 : 2; if (X->Shape().NumDimensions() > 2) { total_dfts /= onnxruntime::narrow(X->Shape()[X->Shape().NumDimensions() - 1]); @@ -349,7 +404,8 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, } size_t Y_offset = 0; - size_t Y_stride = onnxruntime::narrow(Y_shape.SizeFromDimension(SafeInt(axis) + 1) / 2); + size_t Y_last_dim_size = (inverse && is_onesided) ? 1 : 2; + size_t Y_stride = onnxruntime::narrow(Y_shape.SizeFromDimension(SafeInt(axis) + 1) / Y_last_dim_size); cumulative_packed_stride = total_dfts; temp = i; for (size_t r = 0; r < batch_and_signal_rank; r++) { @@ -359,7 +415,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, cumulative_packed_stride /= onnxruntime::narrow(X_shape[r]); auto index = temp / cumulative_packed_stride; temp -= (index * cumulative_packed_stride); - Y_offset += index * SafeInt(Y_shape.SizeFromDimension(r + 1)) / 2; + Y_offset += index * SafeInt(Y_shape.SizeFromDimension(r + 1)) / Y_last_dim_size; } if (is_power_of_2(onnxruntime::narrow(dft_length))) { @@ -367,7 +423,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, is_onesided, inverse, V, temp_output))); } else { ORT_RETURN_IF_ERROR( - (dft_bluestein_z_chirp(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow(dft_length), window, inverse, V, temp_output))); + (dft_bluestein_z_chirp(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow(dft_length), window, is_onesided, inverse, V, temp_output))); } } @@ -383,7 +439,16 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo const auto is_complex_valued = is_complex_valued_signal(X_shape); axis = HandleNegativeAxis(axis, X_shape.NumDimensions()); + // Validate input for IRFFT + if (inverse && is_onesided) { + ORT_RETURN_IF(!is_complex_valued, + "Inverse one-sided DFT (IRFFT) requires complex-valued input (last dimension must be 2)"); + } + int64_t number_of_samples = static_cast(X_shape[onnxruntime::narrow(axis)]); + if (inverse && is_onesided) { + number_of_samples = (number_of_samples - 1) << 1; + } if (dft_length) { const auto& dft_length_shape = dft_length->Shape(); ORT_RETURN_IF(!dft_length_shape.IsScalar(), "dft_length must be a scalar value."); @@ -393,13 +458,17 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo // Get the DFT output size. Onesided will return only the unique values! // note: x >> 1 === std::floor(x / 2.f) - auto dft_output_size = is_onesided ? ((number_of_samples >> 1) + 1) : number_of_samples; + // For IRFFT (inverse && onesided), output is full-size real signal + // For RFFT (!inverse && onesided), output is one-sided complex spectrum + auto dft_output_size = (is_onesided && !inverse) ? ((number_of_samples >> 1) + 1) : number_of_samples; // Get output shape auto Y_shape = onnxruntime::TensorShape(X_shape); - if (X_shape.NumDimensions() == 2) { - Y_shape = onnxruntime::TensorShape({X_shape[0], dft_output_size, 2}); + if (inverse && is_onesided) { + // IRFFT: output is real-valued + Y_shape[Y_shape.NumDimensions() - 1] = 1; // Real output } else { + // Complex output Y_shape[Y_shape.NumDimensions() - 1] = 2; } Y_shape[onnxruntime::narrow(axis)] = dft_output_size; diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index 9b44b5aa4c4fe..77c1bfa933459 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -305,5 +305,226 @@ TEST(SignalOpsTest, MelWeightMatrixFloat) { test.Run(); } +// IRFFT tests - inverse one-sided DFT (complex to real) +static void TestIRFFTRadix2Float(int since_version) { + OpTester test("DFT", since_version); + + // One-sided complex input (result of RFFT on 8 real samples) + vector input_shape = {1, 5, 2}; // floor(8/2) + 1 = 5 frequency bins + vector input = {36.000f, 0.000f, -4.000f, 9.65685f, -4.000f, 4.000f, + -4.000f, 1.65685f, -4.000f, 0.000f}; + + // Expected real output (should match original signal from RFFT) + vector output_shape = {1, 8, 1}; + vector expected_output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("input", input_shape, input); + if (since_version == 20) { + test.AddInput("dft_length", {}, {8}); + test.AddInput("axis", {}, {1}); + } + test.AddAttribute("onesided", static_cast(true)); + test.AddAttribute("inverse", static_cast(true)); + test.AddOutput("output", output_shape, expected_output); + test.SetOutputAbsErr("output", 0.0001f); + test.Run(); +} + +static void TestIRFFTNaiveFloat(int since_version) { + OpTester test("DFT", since_version); + + // One-sided complex input (result of RFFT on 5 real samples) + vector input_shape = {1, 3, 2}; // floor(5/2) + 1 = 3 frequency bins + vector input = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f, 0.8123000f}; + + // Expected real output + vector output_shape = {1, 5, 1}; + vector expected_output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + test.AddInput("input", input_shape, input); + // dft_length is required for IRFFT to distinguish between even and odd lengths + test.AddInput("dft_length", {}, {5}); + if (since_version == 20) { + test.AddInput("axis", {}, {-2}); + } + test.AddAttribute("onesided", static_cast(true)); + test.AddAttribute("inverse", static_cast(true)); + test.AddOutput("output", output_shape, expected_output); + test.SetOutputAbsErr("output", 0.0001f); + test.Run(); +} + +// Test RFFT -> IRFFT round trip +static void TestRFFTIRFFTRoundTrip(int since_version) { + class RFFTIRFFTTester : public OpTester { + public: + explicit RFFTIRFFTTester(int since_version) : OpTester("DFT", since_version) {} + + protected: + void AddNodes(Graph& graph, vector& graph_inputs, vector& graph_outputs, + vector>& add_attribute_funcs) override { + // Create intermediate output for RFFT + ONNX_NAMESPACE::TypeProto intermediate_type; + intermediate_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + vector intermediate_outputs; + intermediate_outputs.push_back(&graph.GetOrCreateNodeArg("rfft_output", &intermediate_type)); + + // Add RFFT node (forward one-sided) + OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs); + + if (this->Opset() < kOpsetVersion20) { + // For opset 17-19, just pass through + } else { + // For opset 20, pass dft_length and axis to IRFFT + assert(graph_inputs.size() == 3); + intermediate_outputs.push_back(graph_inputs[1]); + intermediate_outputs.push_back(graph_inputs[2]); + } + + // Add IRFFT node (inverse one-sided) + Node& irfft = graph.AddNode("irfft", "DFT", "inverse one-sided", intermediate_outputs, graph_outputs); + irfft.AddAttribute("onesided", static_cast(true)); + irfft.AddAttribute("inverse", static_cast(true)); + } + }; + + RFFTIRFFTTester test(since_version); + + // Real input signal + vector input_shape = {2, 8, 1}; + vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, + 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; + + test.AddInput("input", input_shape, input_data); + if (since_version >= kOpsetVersion20) { + test.AddInput("dft_length", {}, {8}); + test.AddInput("axis", {}, {1}); + } + test.AddAttribute("onesided", static_cast(true)); + + // Output should match input (round trip) + test.AddOutput("output", input_shape, input_data); + test.SetOutputAbsErr("output", 0.001f); + test.Run(); +} + +TEST(SignalOpsTest, DFT17_IRFFT_radix2) { + TestIRFFTRadix2Float(kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_IRFFT_radix2) { + TestIRFFTRadix2Float(kOpsetVersion20); +} + +TEST(SignalOpsTest, DFT17_IRFFT_naive) { + TestIRFFTNaiveFloat(kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_IRFFT_naive) { + TestIRFFTNaiveFloat(kOpsetVersion20); +} + +TEST(SignalOpsTest, DFT17_RFFT_IRFFT_roundtrip) { + TestRFFTIRFFTRoundTrip(kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_RFFT_IRFFT_roundtrip) { + TestRFFTIRFFTRoundTrip(kOpsetVersion20); +} + +// Test 2D complex input (single 1D signal without batch dimension) +static void TestDFT2DComplex(int since_version) { + OpTester test("DFT", since_version); + + // 2D complex input: [signal_length, 2] + // This represents a single 1D complex signal without a batch dimension + vector input_shape = {8, 2}; + vector input = { + 1.0f, 0.0f, // complex(1, 0) + 2.0f, 0.0f, // complex(2, 0) + 3.0f, 0.0f, // complex(3, 0) + 4.0f, 0.0f, // complex(4, 0) + 5.0f, 0.0f, // complex(5, 0) + 6.0f, 0.0f, // complex(6, 0) + 7.0f, 0.0f, // complex(7, 0) + 8.0f, 0.0f // complex(8, 0) + }; + + // Expected output: DFT of the complex input + // Should have same shape [8, 2] for complex output + vector output_shape = {8, 2}; + vector expected_output = { + 36.000f, 0.000f, // bin 0 + -4.000f, 9.65685f, // bin 1 + -4.000f, 4.000f, // bin 2 + -4.000f, 1.65685f, // bin 3 + -4.000f, 0.000f, // bin 4 + -4.000f, -1.65685f, // bin 5 + -4.000f, -4.000f, // bin 6 + -4.000f, -9.65685f // bin 7 + }; + + test.AddInput("input", input_shape, input); + if (since_version == 20) { + test.AddInput("dft_length", {}, {8}); + test.AddInput("axis", {}, {0}); // axis=0 for 2D input + } else { + // For Opset 17, set axis attribute explicitly + test.AddAttribute("axis", static_cast(0)); + } + test.AddAttribute("onesided", static_cast(false)); + test.AddOutput("output", output_shape, expected_output); + test.SetOutputAbsErr("output", 0.0001f); + test.Run(); +} + +TEST(SignalOpsTest, DFT17_2D_complex) { + TestDFT2DComplex(kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_2D_complex) { + TestDFT2DComplex(kOpsetVersion20); +} + +// Test 2D real input (single 1D signal without batch dimension) +static void TestDFT2DReal(int since_version) { + OpTester test("DFT", since_version); + + // 2D real input: [signal_length, 1] + // This represents a single 1D real signal without a batch dimension + vector input_shape = {8, 1}; + vector input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + // Expected output: RFFT of the real input (one-sided) + vector output_shape = {5, 2}; // floor(8/2) + 1 = 5 bins + vector expected_output = { + 36.000f, 0.000f, + -4.000f, 9.65685f, + -4.000f, 4.000f, + -4.000f, 1.65685f, + -4.000f, 0.000f}; + + test.AddInput("input", input_shape, input); + if (since_version == 20) { + test.AddInput("dft_length", {}, {8}); + test.AddInput("axis", {}, {0}); // axis=0 for 2D input + } else { + // For Opset 17, set axis attribute explicitly + test.AddAttribute("axis", static_cast(0)); + } + test.AddAttribute("onesided", static_cast(true)); + test.AddOutput("output", output_shape, expected_output); + test.SetOutputAbsErr("output", 0.0001f); + test.Run(); +} + +TEST(SignalOpsTest, DFT17_2D_real) { + TestDFT2DReal(kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_2D_real) { + TestDFT2DReal(kOpsetVersion20); +} + } // namespace test } // namespace onnxruntime