Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 97 additions & 28 deletions onnxruntime/core/providers/cpu/signal/dft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ ONNX_CPU_OPERATOR_KERNEL(STFT, 17,
STFT);

static bool is_real_valued_signal(const onnxruntime::TensorShape& shape) {
return shape.NumDimensions() == 2 || shape[shape.NumDimensions() - 1] == 1;
return shape.NumDimensions() >= 2 && shape[shape.NumDimensions() - 1] == 1;
}

static bool is_complex_valued_signal(const onnxruntime::TensorShape& shape) {
return shape.NumDimensions() > 2 && shape[shape.NumDimensions() - 1] == 2;
return shape.NumDimensions() >= 2 && shape[shape.NumDimensions() - 1] == 2;
}

constexpr static bool is_power_of_2(size_t size) {
Expand Down Expand Up @@ -143,9 +143,28 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s

for (size_t i = 0; i < dft_length; i++) {
size_t bit_reversed_index = bit_reverse(i, significant_bits);
auto x = (bit_reversed_index < number_of_samples) ? *(X_data + bit_reversed_index * X_stride) : 0;
auto window_element = window_data ? *(window_data + bit_reversed_index) : 1;
*(Y_data + i * Y_data_stride) = std::complex<T>(1, 0) * x * window_element;
if (bit_reversed_index < number_of_samples) {
auto x = *(X_data + bit_reversed_index * X_stride);
auto window_element = window_data ? *(window_data + bit_reversed_index) : 1;
*(Y_data + i * Y_data_stride) = std::complex<T>(1, 0) * x * window_element;
}
}

// For IRFFT, fix up the negative frequencies using conjugate symmetry
if (is_onesided && inverse) {
size_t conjugate_end = (dft_length % 2 == 0) ? (number_of_samples - 1) : number_of_samples;
for (size_t k = 1; k < conjugate_end; k++) {
// Find position in bit-reversed array that represents natural index N-k
size_t pos_nk = bit_reverse(dft_length - k, significant_bits);

// Get the input value and window at position k
auto x_k = *(X_data + k * X_stride);
auto window_k = window_data ? *(window_data + k) : 1;

// Apply conjugate symmetry to input, then apply window
// X[N-k] = conj(X[k]), then multiply by window
*(Y_data + pos_nk * Y_data_stride) = std::conj(x_k) * window_k;
}
}

// Run fft_radix2
Expand Down Expand Up @@ -179,10 +198,17 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s
}

if (is_onesided) {
const size_t output_size = (dft_length >> 1) + 1;
auto destination = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < output_size; i++) {
*(destination + Y_stride * i) = *(Y_data + i * Y_data_stride);
if (inverse) {
auto destination = reinterpret_cast<T*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < dft_length; i++) {
*(destination + Y_stride * i) = (*(Y_data + i * Y_data_stride)).real();
}
} else {
const size_t output_size = (dft_length >> 1) + 1;
auto destination = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < output_size; i++) {
*(destination + Y_stride * i) = *(Y_data + i * Y_data_stride);
}
}
}

Expand All @@ -202,7 +228,7 @@ T next_power_of_2(T in) {
template <typename T, typename U>
static Status dft_bluestein_z_chirp(
OpKernelContext* ctx, const Tensor* X, Tensor* Y, Tensor& b_fft, Tensor& chirp, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride,
int64_t axis, size_t dft_length, const Tensor* window, bool inverse, InlinedVector<std::complex<T>>& V,
int64_t axis, size_t dft_length, const Tensor* window, bool is_onesided, bool inverse, InlinedVector<std::complex<T>>& V,
InlinedVector<std::complex<T>>& temp_output) {
static constexpr T pi = static_cast<T>(M_PI);

Expand Down Expand Up @@ -255,7 +281,6 @@ static Status dft_bluestein_z_chirp(

// Get data
auto* X_data = const_cast<U*>(reinterpret_cast<const U*>(X->DataRaw())) + X_offset;
auto* Y_data = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
U* window_data = nullptr;
if (window) {
window_data = const_cast<U*>(reinterpret_cast<const U*>(window->DataRaw()));
Expand All @@ -281,6 +306,20 @@ static Status dft_bluestein_z_chirp(
a_n *= window_n;
a_n *= chirp_n;
}
if (inverse && is_onesided) {
// IRFFT: fill in negative frequencies using conjugate symmetry
// For the input X: X[N-k] = conj(X[k]) for k=1..floor((N-1)/2)
// We need to apply this BEFORE the chirp multiplication
// So: a[N-k] = conj(X[k]) * window[N-k] * chirp[N-k]
// = conj(X[k]) * window[k] * chirp[N-k] (window is symmetric)
size_t conjugate_end = (N % 2 == 0) ? (number_of_samples - 1) : number_of_samples;
for (size_t k = 1; k < conjugate_end; k++) {
auto x_k = *(X_data + k * X_stride); // Original input at k
auto window_k = window_data ? *(window_data + k) : 1;
std::complex<T>& chirp_nk = *(chirp_data + N - k);
*(a_data + N - k) = std::conj(x_k) * window_k * chirp_nk;
}
}

// Forward FFT radix2 for the "a" signal
ORT_RETURN_IF_ERROR((fft_radix2<T, std::complex<T>>(ctx, &a, &a_fft, 0, 1, 0, 1, 1, M, nullptr,
Expand All @@ -298,17 +337,33 @@ static Status dft_bluestein_z_chirp(
const auto& Y_shape = Y->Shape();
size_t dft_output_size = static_cast<size_t>(Y_shape[onnxruntime::narrow<size_t>(axis)]);

for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
std::complex<T>& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
// The inverse fft is computed using the same cached vandermonde matrix (V) created by the
// forward fft. This reversal causes the output to be reversed as well.
// Therefore we undo the reversal when writing the output back out.
c_i = *(a_data + M - i);
if (inverse && is_onesided) {
// IRFFT: extract real part only
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
T& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
c_i = *(a_data + M - i);
}
out = (c_i * chirp_i * scale).real();
}
} else {
// Standard complex output
auto* Y_data = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
std::complex<T>& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
// The inverse fft is computed using the same cached vandermonde matrix (V) created by the
// forward fft. This reversal causes the output to be reversed as well.
// Therefore we undo the reversal when writing the output back out.
c_i = *(a_data + M - i);
}
out = c_i * chirp_i * scale;
}
out = c_i * chirp_i * scale;
}
return Status::OK();
}
Expand All @@ -325,7 +380,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X,
auto batch_and_signal_rank = X->Shape().NumDimensions();
auto total_dfts = static_cast<size_t>(X->Shape().Size() / X->Shape()[onnxruntime::narrow<size_t>(axis)]);

auto is_input_real = X->Shape().NumDimensions() == 2 || X->Shape()[X->Shape().NumDimensions() - 1] == 1;
auto is_input_real = X->Shape()[X->Shape().NumDimensions() - 1] == 1;
auto complex_input_factor = is_input_real ? 1 : 2;
if (X->Shape().NumDimensions() > 2) {
total_dfts /= onnxruntime::narrow<size_t>(X->Shape()[X->Shape().NumDimensions() - 1]);
Expand All @@ -349,7 +404,8 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X,
}

size_t Y_offset = 0;
size_t Y_stride = onnxruntime::narrow<size_t>(Y_shape.SizeFromDimension(SafeInt<size_t>(axis) + 1) / 2);
size_t Y_last_dim_size = (inverse && is_onesided) ? 1 : 2;
size_t Y_stride = onnxruntime::narrow<size_t>(Y_shape.SizeFromDimension(SafeInt<size_t>(axis) + 1) / Y_last_dim_size);
cumulative_packed_stride = total_dfts;
temp = i;
for (size_t r = 0; r < batch_and_signal_rank; r++) {
Expand All @@ -359,15 +415,15 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X,
cumulative_packed_stride /= onnxruntime::narrow<size_t>(X_shape[r]);
auto index = temp / cumulative_packed_stride;
temp -= (index * cumulative_packed_stride);
Y_offset += index * SafeInt<size_t>(Y_shape.SizeFromDimension(r + 1)) / 2;
Y_offset += index * SafeInt<size_t>(Y_shape.SizeFromDimension(r + 1)) / Y_last_dim_size;
}

if (is_power_of_2(onnxruntime::narrow<size_t>(dft_length))) {
ORT_RETURN_IF_ERROR((fft_radix2<T, U>(ctx, X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow<size_t>(dft_length), window,
is_onesided, inverse, V, temp_output)));
} else {
ORT_RETURN_IF_ERROR(
(dft_bluestein_z_chirp<T, U>(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow<size_t>(dft_length), window, inverse, V, temp_output)));
(dft_bluestein_z_chirp<T, U>(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow<size_t>(dft_length), window, is_onesided, inverse, V, temp_output)));
}
}

Expand All @@ -383,7 +439,16 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo
const auto is_complex_valued = is_complex_valued_signal(X_shape);
axis = HandleNegativeAxis(axis, X_shape.NumDimensions());

// Validate input for IRFFT
if (inverse && is_onesided) {
ORT_RETURN_IF(!is_complex_valued,
"Inverse one-sided DFT (IRFFT) requires complex-valued input (last dimension must be 2)");
}

int64_t number_of_samples = static_cast<int64_t>(X_shape[onnxruntime::narrow<size_t>(axis)]);
if (inverse && is_onesided) {
number_of_samples = (number_of_samples - 1) << 1;
}
if (dft_length) {
const auto& dft_length_shape = dft_length->Shape();
ORT_RETURN_IF(!dft_length_shape.IsScalar(), "dft_length must be a scalar value.");
Expand All @@ -393,13 +458,17 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo

// Get the DFT output size. Onesided will return only the unique values!
// note: x >> 1 === std::floor(x / 2.f)
auto dft_output_size = is_onesided ? ((number_of_samples >> 1) + 1) : number_of_samples;
// For IRFFT (inverse && onesided), output is full-size real signal
// For RFFT (!inverse && onesided), output is one-sided complex spectrum
auto dft_output_size = (is_onesided && !inverse) ? ((number_of_samples >> 1) + 1) : number_of_samples;

// Get output shape
auto Y_shape = onnxruntime::TensorShape(X_shape);
if (X_shape.NumDimensions() == 2) {
Y_shape = onnxruntime::TensorShape({X_shape[0], dft_output_size, 2});
if (inverse && is_onesided) {
// IRFFT: output is real-valued
Y_shape[Y_shape.NumDimensions() - 1] = 1; // Real output
} else {
// Complex output
Y_shape[Y_shape.NumDimensions() - 1] = 2;
}
Y_shape[onnxruntime::narrow<size_t>(axis)] = dft_output_size;
Expand Down
Loading
Loading