forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSpectralOps.cpp
364 lines (319 loc) · 13.9 KB
/
SpectralOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/Config.h>
#if !AT_MKL_ENABLED()
namespace at { namespace native {

// Built without MKL: register the conjugate-symmetry fill stub as having no
// CPU implementation.
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub, fft_fill_with_conjugate_symmetry_fn);

// Fallback for builds without MKL. None of the arguments are inspected; every
// call raises an error stating that ATen was not compiled with MKL support.
Tensor _fft_mkl(const Tensor& /*input*/, int64_t /*signal_ndim*/,
                bool /*complex_input*/, bool /*complex_output*/,
                bool /*inverse*/, IntArrayRef /*checked_signal_sizes*/,
                int64_t /*normalization*/, bool /*onesided*/,
                IntArrayRef /*output_sizes*/) {
  AT_ERROR("fft: ATen not compiled with MKL support");
}

}} // namespace at::native
#else // AT_MKL_ENABLED
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/native/TensorIterator.h>
#include <algorithm>
#include <vector>
#include <numeric>
#include <cmath>
#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>
namespace at { namespace native {
// In real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following functions are used to fill in the other half with symmetry in
// case of real-to-complex transform with onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
//
// _fft_fill_with_conjugate_symmetry_slice handles the elements whose linear
// indices lie in [range.begin, range.end) of an index space shaped like
// `signal_half_sizes`, with dimension 0 varying fastest. For each element it
// writes std::conj(input element) to the corresponding output position:
//   - along a dimension d with is_mirrored_dim[d], input index k maps to output
//     index 0 when k == 0, and to output index (signal_half_sizes[d] - k) otherwise;
//   - along every other dimension the output index equals the input index.
// in_strides / out_strides are in units of scalar_t elements. The UBSAN note
// below suggests callers may pass negative strides -- NOTE(review): confirm.
// NOTE(review): advance_index's wrap-around for a mirrored dimension d >= 1
// steps out_ptr back by one stride, which only returns it to output index 0 when
// signal_half_sizes[d] >= 2. With a size-1 mirrored dimension out_ptr would drift
// to index -1. Confirm callers never pass such a dimension.
template <typename scalar_t>
static __ubsan_ignore_undefined__ // UBSAN gives false positives on using negative indexes with a pointer
void _fft_fill_with_conjugate_symmetry_slice(
Range range, at::ArrayRef<bool> is_mirrored_dim, IntArrayRef signal_half_sizes,
IntArrayRef in_strides, const scalar_t * in_ptr,
IntArrayRef out_strides, scalar_t * out_ptr) {
const auto ndim = signal_half_sizes.size();
// Multi-dimensional index of the current position. iter_index[0] is the
// position within a row (dimension 0); the rest select the row.
DimVector iter_index(ndim, 0);
// We explicitly loop over one row, then use this lambda to iterate over
// n-dimensions. This advances iter_index by one row, while updating in_ptr
// and out_ptr to point to the new row of data.
auto advance_index = [&] {
for (size_t i = 1; i < iter_index.size(); ++i) {
if (iter_index[i] + 1 < signal_half_sizes[i]) {
++iter_index[i];
in_ptr += in_strides[i];
if (is_mirrored_dim[i]) {
// Mirrored output index runs 0, size-1, size-2, ...: the first step
// jumps to the far end, later steps move one stride backwards.
if (iter_index[i] == 1) {
out_ptr += (signal_half_sizes[i] - 1) * out_strides[i];
} else {
out_ptr -= out_strides[i];
}
} else {
out_ptr += out_strides[i];
}
return;
}
// Dimension i overflowed: rewind it to index 0 and carry into dimension i+1.
in_ptr -= in_strides[i] * iter_index[i];
if (is_mirrored_dim[i]) {
// At the last input index (size-1) the mirrored output index is 1, so
// one step back lands on output index 0.
out_ptr -= out_strides[i];
} else {
out_ptr -= out_strides[i] * iter_index[i];
}
iter_index[i] = 0;
}
};
// The data slice we operate on may start part-way into the data
// Update iter_index and pointers to reference the start of the slice
if (range.begin > 0) {
// Dimension 0 is addressed by indexing (i * stride) inside the row loops
// below, so only dimensions >= 1 move the pointers here.
iter_index[0] = range.begin % signal_half_sizes[0];
auto linear_idx = range.begin / signal_half_sizes[0];
// Once linear_idx reaches 0 every remaining index stays 0.
for (size_t i = 1; i < ndim && linear_idx > 0; ++i) {
iter_index[i] = linear_idx % signal_half_sizes[i];
linear_idx = linear_idx / signal_half_sizes[i];
if (iter_index[i] > 0) {
in_ptr += in_strides[i] * iter_index[i];
if (is_mirrored_dim[i]) {
out_ptr += out_strides[i] * (signal_half_sizes[i] - iter_index[i]);
} else {
out_ptr += out_strides[i] * iter_index[i];
}
}
}
}
auto numel_remaining = range.end - range.begin;
if (is_mirrored_dim[0]) {
// Explicitly loop over a Hermitian mirrored dimension
if (iter_index[0] > 0) {
// Finish the partial first row when the slice starts mid-row. Here
// i > 0, so index i always maps to (size - i).
auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
for (int64_t i = iter_index[0]; i < end; ++i) {
out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
}
numel_remaining -= (end - iter_index[0]);
iter_index[0] = 0;
advance_index();
}
// Rows starting at index 0; the last one may be partial. Index 0 maps to
// itself, index i > 0 to (size - i).
while (numel_remaining > 0) {
auto end = std::min(signal_half_sizes[0], numel_remaining);
out_ptr[0] = std::conj(in_ptr[0]);
for (int64_t i = 1; i < end; ++i) {
out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
}
numel_remaining -= end;
advance_index();
}
} else {
// Explicit loop over a non-mirrored dimension, so just a simple conjugated copy
while (numel_remaining > 0) {
auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
for (int64_t i = iter_index[0]; i != end; ++i) {
out_ptr[i * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
}
numel_remaining -= (end - iter_index[0]);
// Only the first row can start mid-row; all later rows start at 0.
iter_index[0] = 0;
advance_index();
}
}
}
// CPU kernel behind fft_fill_with_conjugate_symmetry_stub.
//
// Converts the incoming byte strides to element strides of `dtype`, builds a
// per-dimension "is mirrored" mask from `mirror_dims`, and then splits the
// prod(signal_half_sizes) elements across threads with at::parallel_for. Each
// chunk is handled by _fft_fill_with_conjugate_symmetry_slice. Only complex
// dtypes are dispatched (AT_DISPATCH_COMPLEX_TYPES).
//
// Byte strides are not assumed to be non-negative. itemsize() yields an
// unsigned size_t, and `int64_t / size_t` is evaluated in unsigned arithmetic.
// The element size is therefore converted to int64_t first, so a negative byte
// stride becomes a negative element stride rather than a huge positive value
// that only works through pointer wrap-around.
static void _fft_fill_with_conjugate_symmetry_cpu_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides_bytes, const void * in_data,
    IntArrayRef out_strides_bytes, void * out_data) {
  const auto ndim = signal_half_sizes.size();
  TORCH_INTERNAL_ASSERT(in_strides_bytes.size() >= ndim && out_strides_bytes.size() >= ndim);

  // Convert strides from bytes to elements, using signed division.
  const auto element_size = static_cast<int64_t>(scalarTypeToTypeMeta(dtype).itemsize());
  DimVector in_strides(ndim), out_strides(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    TORCH_INTERNAL_ASSERT(in_strides_bytes[i] % element_size == 0);
    in_strides[i] = in_strides_bytes[i] / element_size;
    TORCH_INTERNAL_ASSERT(out_strides_bytes[i] % element_size == 0);
    out_strides[i] = out_strides_bytes[i] / element_size;
  }

  // Construct boolean mask for mirrored dims (each entry must index a dim).
  c10::SmallVector<bool, at::kDimVectorStaticSize> is_mirrored_dim(ndim, false);
  for (const auto& dim : mirror_dims) {
    TORCH_INTERNAL_ASSERT(dim >= 0 && dim < static_cast<int64_t>(ndim));
    is_mirrored_dim[dim] = true;
  }

  const auto numel = at::prod_intlist(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
    at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          _fft_fill_with_conjugate_symmetry_slice(
              {begin, end}, is_mirrored_dim, signal_half_sizes,
              in_strides, static_cast<const scalar_t*>(in_data),
              out_strides, static_cast<scalar_t*>(out_data));
        });
  });
}
// Register this single CPU implementation for the DEFAULT, AVX and AVX2
// dispatch slots, so the kernel is compiled once instead of once per CPU
// capability.
REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
// Builds and commits an MKL DFTI descriptor for a batched, out-of-place
// transform.
//
// sizes:         {batch, n_1, ..., n_k}. The signal part is the full signal,
//                always two-sided, even for real<->complex transforms.
// in_strides / out_strides: one stride per entry of `sizes`, in elements.
//                For complex types the unit is 2 * element_size(dtype), i.e.
//                one complex value.
// forward:       picks which side's real/complex-ness defines the descriptor
//                domain, and whether the forward or backward scale is set.
// normalization: an fft_norm_mode. `none` sets no scale, `by_root_n` scales
//                by 1/sqrt(N), anything else by 1/N, where N = prod(sizes[1:]).
// Throws for dtypes that are not float/double based, or when a stride or the
// signal numel does not fit in MKL_LONG.
static DftiDescriptor _plan_mkl_fft(
    IntArrayRef in_strides, IntArrayRef out_strides, IntArrayRef sizes,
    bool complex_input, bool complex_output,
    int64_t normalization, bool forward, ScalarType dtype) {
  const int64_t signal_ndim = sizes.size() - 1;
  TORCH_INTERNAL_ASSERT(in_strides.size() == sizes.size());
  TORCH_INTERNAL_ASSERT(out_strides.size() == sizes.size());

  // Precision comes from the real value type of dtype. The default branch
  // throws, so the initial value is never observed.
  DFTI_CONFIG_VALUE precision = DFTI_SINGLE;
  switch (c10::toValueType(dtype)) {
    case ScalarType::Float:
      precision = DFTI_SINGLE;
      break;
    case ScalarType::Double:
      precision = DFTI_DOUBLE;
      break;
    default:
      TORCH_CHECK(false, "MKL FFT doesn't support tensors of type: ", dtype);
  }

  // Domain is decided by the input of a forward transform, or the output of
  // a backward one.
  const bool forward_side_is_complex = forward ? complex_input : complex_output;
  const DFTI_CONFIG_VALUE domain = forward_side_is_complex ? DFTI_COMPLEX : DFTI_REAL;

  // Create the descriptor from the signal sizes (batch dim excluded).
  using MklDimVector = c10::SmallVector<MKL_LONG, at::kDimVectorStaticSize>;
  MklDimVector mkl_signal_sizes(sizes.begin() + 1, sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(precision, domain, signal_ndim, mkl_signal_sizes.data());
  const auto handle = descriptor.get();

  // Never write the result over the input.
  MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE));

  // Batching: sizes[0] transforms, spaced by the dim-0 strides.
  const MKL_LONG mkl_batch_size = sizes[0];
  MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_NUMBER_OF_TRANSFORMS, mkl_batch_size));
  TORCH_CHECK(in_strides[0] <= MKL_LONG_MAX && out_strides[0] <= MKL_LONG_MAX);
  const MKL_LONG idist = in_strides[0];
  const MKL_LONG odist = out_strides[0];
  MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_OUTPUT_DISTANCE, odist));

  // Per-dimension signal strides. Entry 0 of MKL's stride arrays is the
  // offset, which stays zero.
  MklDimVector mkl_istrides(1 + signal_ndim, 0);
  MklDimVector mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    TORCH_CHECK(in_strides[i] <= MKL_LONG_MAX && out_strides[i] <= MKL_LONG_MAX);
    mkl_istrides[i] = in_strides[i];
    mkl_ostrides[i] = out_strides[i];
  }
  MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));

  // Any real side means a conjugate-even domain is involved; request the
  // standard CCE storage (expected to become MKL's default in future).
  if (!(complex_input && complex_output)) {
    MKL_DFTI_CHECK(DftiSetValue(handle, DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }

  // Optional rescaling in the direction being planned.
  const auto norm = static_cast<fft_norm_mode>(normalization);
  const int64_t signal_numel = at::prod_intlist(IntArrayRef(sizes.data() + 1, signal_ndim));
  if (norm != fft_norm_mode::none) {
    const double n = static_cast<double>(signal_numel);
    const double scale = (norm == fft_norm_mode::by_root_n) ? 1.0 / std::sqrt(n) : 1.0 / n;
    MKL_DFTI_CHECK(DftiSetValue(
        handle, forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE, scale));
  }

  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    TORCH_CHECK(signal_numel <= MKL_LONG_MAX,
                "MKL FFT: input signal numel exceeds allowed range [1, ", MKL_LONG_MAX, "]");
  }

  MKL_DFTI_CHECK(DftiCommitDescriptor(handle));
  return descriptor;
}
// MKL DFTI
//
// Batched FFT of `self` via MKL. `self` is laid out as {batch, signal...}.
// When `complex_input` is true, the complex data is viewed as real with a
// trailing real/imag dimension. The result is allocated with shape
// `output_sizes` and the options of the (possibly made contiguous) input.
// `checked_signal_sizes` is the full two-sided signal shape handed to MKL.
// For a real input with complex output and onesided == false, the half MKL
// does not write is filled in afterwards using Hermitian symmetry.
Tensor _fft_mkl(const Tensor& self, int64_t signal_ndim,
                bool complex_input, bool complex_output,
                bool inverse, IntArrayRef checked_signal_sizes,
                int64_t normalization, bool onesided,
                IntArrayRef output_sizes) {
  Tensor input = self;
  bool need_contiguous = false;
  // When complex data is viewed as real, the real/imag dimension must be
  // innermost (stride 1) and every other stride even, so that all strides can
  // be expressed in whole complex elements.
  if (complex_input) {
    need_contiguous |= input.stride(-1) != 1;
    for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) {
      need_contiguous |= input.stride(i) % 2 != 0;
    }
  }

  // MKL_LONG is 32-bit on some platforms (e.g. Windows), so check that every
  // input/output size and stride handed to MKL fits. Complex-domain strides
  // are divided by 2. Only the upper bound MKL_LONG_MAX is tested because
  // these values are non-negative.
  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    int64_t inumel = 1 /* istride if we contiguous-fy */, onumel = 1;
    for (int64_t i = signal_ndim; i >= 0; i--) {
      const int64_t isize = input.size(i);
      const int64_t osize = output_sizes[i];
      const int64_t istride = complex_input ? input.stride(i) >> 1 : input.stride(i);
      const int64_t ostride = onumel;
      TORCH_CHECK(isize <= MKL_LONG_MAX && osize <= MKL_LONG_MAX && ostride <= MKL_LONG_MAX,
                  "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]");
      if (!need_contiguous && istride > MKL_LONG_MAX) {
        // The current input stride is out of range, but the stride after
        // making the input contiguous (equal to `inumel`) may still fit.
        // Switch to the contiguous plan and check `inumel` for this and all
        // remaining dimensions. Dimensions already visited need no re-check
        // because `inumel` is non-decreasing.
        need_contiguous = true;
      }
      TORCH_CHECK(!need_contiguous || inumel <= MKL_LONG_MAX,
                  "MKL FFT: input signal numel exceeds allowed range [1 ~ ", MKL_LONG_MAX, "]");
      inumel *= isize;
      onumel *= osize;
    }
  }

  if (need_contiguous) {
    input = input.contiguous();
  }

  Tensor output = at::empty(output_sizes, input.options());

  // MKL sizes: batch count followed by the full signal shape.
  DimVector full_sizes(signal_ndim + 1);
  full_sizes[0] = self.size(0);
  std::copy(checked_signal_sizes.cbegin(), checked_signal_sizes.cend(), full_sizes.begin() + 1);

  // If "complex" is true, convert strides from complex viewed as real to complex strides.
  // Otherwise, returns a copy of strides if "complex" is false.
  auto convert_strides = [signal_ndim](IntArrayRef strides, bool complex) {
    DimVector res(signal_ndim + 1);
    if (complex) {
      for (size_t i = 0; i < res.size(); ++i) {
        res[i] = strides[i] / 2;
      }
    } else {
      res.assign(strides.cbegin(), strides.cend());
    }
    return res;
  };
  const auto in_strides = convert_strides(input.strides(), complex_input);
  const auto out_strides = convert_strides(output.strides(), complex_output);

  auto descriptor = _plan_mkl_fft(
      in_strides, out_strides, full_sizes, complex_input, complex_output,
      normalization, !inverse, input.scalar_type());

  if (inverse) {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), output.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), output.data_ptr()));
  }

  // now if needed, fill out the other half using Hermitian symmetry dim
  if (!complex_input && complex_output && !onesided) {
    DimVector signal_dims(signal_ndim);
    std::iota(signal_dims.begin(), signal_dims.end(), 1);
    auto out_as_complex = at::view_as_complex(output);
    at::native::_fft_fill_with_conjugate_symmetry_(out_as_complex, signal_dims);
  }
  return output;
}
}} // namespace at::native
#endif