diff --git a/icicle_v3/backend/cpu/include/cpu_ntt.h b/icicle_v3/backend/cpu/include/cpu_ntt.h
index 839936d37..b20903885 100644
--- a/icicle_v3/backend/cpu/include/cpu_ntt.h
+++ b/icicle_v3/backend/cpu/include/cpu_ntt.h
@@ -6,6 +6,7 @@
 #include "icicle/fields/field_config.h"
 #include "icicle/vec_ops.h"
+#include <unordered_map>
 #include <vector>
 #include <memory>
 #include <mutex>
@@ -19,6 +20,14 @@ using namespace icicle;
 
 namespace ntt_cpu {
 
+  // TODO SHANIE - after implementing real parallelism, try different sizes to choose the optimal one. Or consider
+  // using a function to calculate subset sizes
+  constexpr uint32_t layers_subntt_log_size[31][3] = {
+    {0, 0, 0},   {1, 0, 0},   {2, 0, 0},   {3, 0, 0},   {4, 0, 0},   {5, 0, 0},   {3, 3, 0},   {4, 3, 0},
+    {4, 4, 0},   {5, 4, 0},   {5, 5, 0},   {4, 4, 3},   {4, 4, 4},   {5, 4, 4},   {5, 5, 4},   {5, 5, 5},
+    {8, 8, 0},   {9, 8, 0},   {9, 9, 0},   {10, 9, 0},  {10, 10, 0}, {11, 10, 0}, {11, 11, 0}, {12, 11, 0},
+    {12, 12, 0}, {13, 12, 0}, {13, 13, 0}, {14, 13, 0}, {14, 14, 0}, {15, 14, 0}, {15, 15, 0}};
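+  // For example, logn = 11 maps to {4, 4, 3}: a 2^11-point NTT is computed as three layers of 2^4-, 2^4- and
+  // 2^3-point sub-NTTs (each row's entries sum to the row index, i.e. to the log of the full NTT size).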
+
   template <typename S>
   class CpuNttDomain
   {
@@ -28,6 +37,8 @@ namespace ntt_cpu {
     std::mutex domain_mutex;
 
   public:
+    std::unordered_map<S, int> coset_index = {};
+
     static eIcicleError
     cpu_ntt_init_domain(const Device& device, const S& primitive_root, const NTTInitDomainConfig& config);
     static eIcicleError cpu_ntt_release_domain(const Device& device);
@@ -92,6 +103,7 @@ namespace ntt_cpu {
       temp_twiddles[0] = S::one();
       for (int i = 1; i <= s_ntt_domain.max_size; i++) {
        temp_twiddles[i] = temp_twiddles[i - 1] * tw_omega;
+        s_ntt_domain.coset_index[temp_twiddles[i]] = i;
       }
       s_ntt_domain.twiddles = std::move(temp_twiddles); // Assign twiddles using unique_ptr
     }
@@ -130,36 +142,101 @@ namespace ntt_cpu {
     return rev;
   }
 
+  inline uint64_t idx_in_mem(
+    int element, int block_idx, int subntt_idx, const std::vector<int> layers_sntt_log_size = {}, int layer = 0)
+  {
+    int s0 = layers_sntt_log_size[0];
+    int s1 = layers_sntt_log_size[1];
+    int s2 = layers_sntt_log_size[2];
+    switch (layer) {
+    case 0:
+      return block_idx + ((subntt_idx + (element << s1)) << s2);
+    case 1:
+      return block_idx + ((element + (subntt_idx << s1)) << s2);
+    case 2:
+      return ((block_idx << (s1 + s2)) & ((1 << (s0 + s1 + s2)) - 1)) +
+             (((block_idx << (s1 + s2)) >> (s0 + s1 + s2)) << s2) + element;
+    default:
+      ICICLE_ASSERT(false) << "Unsupported layer";
+    }
+    return -1;
+  }
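+  // Illustrative walkthrough of the mapping above, for layers_sntt_log_size = {4, 4, 3} (log_size 11):
+  //   layer 0: idx = block_idx + subntt_idx * 2^3 + element * 2^7   (element stride 2^(s1+s2) = 128)
+  //   layer 1: idx = block_idx + element * 2^3 + subntt_idx * 2^7   (element stride 2^s2 = 8)
+  //   layer 2: idx = (block_idx % 2^4) * 2^7 + (block_idx / 2^4) * 2^3 + element   (elements are contiguous)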
+
   template <typename E = scalar_t>
-  eIcicleError reorder_by_bit_reverse(int logn, E* output, int batch_size)
+  eIcicleError reorder_by_bit_reverse(
+    int log_original_size,
+    E* elements,
+    int batch_size,
+    bool columns_batch,
+    int block_idx = 0,
+    int subntt_idx = 0,
+    std::vector<int> layers_sntt_log_size = {},
+    int layer = 0)
   {
-    uint64_t size = 1 << logn;
+    uint64_t subntt_size = (layers_sntt_log_size.empty()) ? 1 << log_original_size : 1 << layers_sntt_log_size[layer];
+    int subntt_log_size = (layers_sntt_log_size.empty()) ? log_original_size : layers_sntt_log_size[layer];
+    uint64_t original_size = (1 << log_original_size);
+    int stride = columns_batch ? batch_size : 1;
     for (int batch = 0; batch < batch_size; ++batch) {
-      E* current_output = output + batch * size;
-      int rev;
-      for (int i = 0; i < size; ++i) {
-        rev = bit_reverse(i, logn);
-        if (i < rev) { std::swap(current_output[i], current_output[rev]); }
+      E* current_elements = columns_batch ? elements + batch : elements + batch * original_size;
+      uint64_t rev;
+      uint64_t i_mem_idx;
+      uint64_t rev_mem_idx;
+      for (uint64_t i = 0; i < subntt_size; ++i) {
+        rev = bit_reverse(i, subntt_log_size);
+        if (!layers_sntt_log_size.empty()) {
+          i_mem_idx = idx_in_mem(i, block_idx, subntt_idx, layers_sntt_log_size, layer);
+          rev_mem_idx = idx_in_mem(rev, block_idx, subntt_idx, layers_sntt_log_size, layer);
+        } else {
+          i_mem_idx = i;
+          rev_mem_idx = rev;
+        }
+        if (i < rev) {
+          if (i_mem_idx < original_size && rev_mem_idx < original_size) { // Ensure indices are within bounds
+            std::swap(current_elements[stride * i_mem_idx], current_elements[stride * rev_mem_idx]);
+          } else {
+            // Handle out-of-bounds error
+            ICICLE_LOG_ERROR << "i=" << i << ", rev=" << rev << ", original_size=" << original_size;
+            ICICLE_LOG_ERROR << "Index out of bounds: i_mem_idx=" << i_mem_idx << ", rev_mem_idx=" << rev_mem_idx;
+            return eIcicleError::INVALID_ARGUMENT;
+          }
+        }
       }
     }
     return eIcicleError::SUCCESS;
   }
 
   template <typename S = scalar_t, typename E = scalar_t>
-  void dit_ntt(E* elements, uint64_t size, int batch_size, const S* twiddles, NTTDir dir, int domain_max_size)
+  void dit_ntt(
+    E* elements,
+    uint64_t total_ntt_size,
+    int batch_size,
+    bool columns_batch,
+    const S* twiddles,
+    NTTDir dir,
+    int domain_max_size,
+    int block_idx = 0,
+    int subntt_idx = 0,
+    std::vector<int> layers_sntt_log_size = {},
+    int layer = 0) // R --> N
   {
+    uint64_t subntt_size = 1 << layers_sntt_log_size[layer];
+    int stride = columns_batch ? batch_size : 1;
     for (int batch = 0; batch < batch_size; ++batch) {
-      E* current_elements = elements + batch * size;
-      for (int len = 2; len <= size; len <<= 1) {
+      E* current_elements = columns_batch ? elements + batch : elements + batch * total_ntt_size;
+      for (int len = 2; len <= subntt_size; len <<= 1) {
         int half_len = len / 2;
-        int step = (size / len) * (domain_max_size / size);
-        for (int i = 0; i < size; i += len) {
+        int step = (subntt_size / len) * (domain_max_size / subntt_size);
+        for (int i = 0; i < subntt_size; i += len) {
           for (int j = 0; j < half_len; ++j) {
             int tw_idx = (dir == NTTDir::kForward) ? j * step : domain_max_size - j * step;
-            E u = current_elements[i + j];
-            E v = current_elements[i + j + half_len] * twiddles[tw_idx];
-            current_elements[i + j] = u + v;
-            current_elements[i + j + half_len] = u - v;
+            uint64_t u_mem_idx = stride * idx_in_mem(i + j, block_idx, subntt_idx, layers_sntt_log_size, layer);
+            uint64_t v_mem_idx =
+              stride * idx_in_mem(i + j + half_len, block_idx, subntt_idx, layers_sntt_log_size, layer);
+            E u = current_elements[u_mem_idx];
+            E v = current_elements[v_mem_idx] * twiddles[tw_idx];
+            current_elements[u_mem_idx] = u + v;
+            current_elements[v_mem_idx] = u - v;
           }
         }
       }
@@ -167,53 +244,61 @@ namespace ntt_cpu {
   }
 
   template <typename S = scalar_t, typename E = scalar_t>
-  void dif_ntt(E* elements, uint64_t size, int batch_size, const S* twiddles, NTTDir dir, int domain_max_size)
+  void dif_ntt(
+    E* elements,
+    uint64_t total_ntt_size,
+    int batch_size,
+    bool columns_batch,
+    const S* twiddles,
+    NTTDir dir,
+    int domain_max_size,
+    int block_idx = 0,
+    int subntt_idx = 0,
+    std::vector<int> layers_sntt_log_size = {},
+    int layer = 0)
   {
+    uint64_t subntt_size = 1 << layers_sntt_log_size[layer];
+    int stride = columns_batch ? batch_size : 1;
     for (int batch = 0; batch < batch_size; ++batch) {
-      E* current_elements = elements + batch * size;
-      for (int len = size; len >= 2; len >>= 1) {
+      E* current_elements = columns_batch ? elements + batch : elements + batch * total_ntt_size;
+      for (int len = subntt_size; len >= 2; len >>= 1) {
         int half_len = len / 2;
-        int step = (size / len) * (domain_max_size / size);
-        for (int i = 0; i < size; i += len) {
+        int step = (subntt_size / len) * (domain_max_size / subntt_size);
+        for (int i = 0; i < subntt_size; i += len) {
           for (int j = 0; j < half_len; ++j) {
             int tw_idx = (dir == NTTDir::kForward) ? j * step : domain_max_size - j * step;
-            E u = current_elements[i + j];
-            E v = current_elements[i + j + half_len];
-            current_elements[i + j] = u + v;
-            current_elements[i + j + half_len] = (u - v) * twiddles[tw_idx];
+            uint64_t u_mem_idx = stride * idx_in_mem(i + j, block_idx, subntt_idx, layers_sntt_log_size, layer);
+            uint64_t v_mem_idx =
+              stride * idx_in_mem(i + j + half_len, block_idx, subntt_idx, layers_sntt_log_size, layer);
+            E u = current_elements[u_mem_idx];
+            E v = current_elements[v_mem_idx];
+            current_elements[u_mem_idx] = u + v;
+            current_elements[v_mem_idx] = (u - v) * twiddles[tw_idx];
           }
         }
       }
     }
   }
 
-  template <typename E = scalar_t>
-  void transpose(const E* input, E* output, int rows, int cols)
-  {
-    for (int col = 0; col < cols; ++col) {
-      for (int row = 0; row < rows; ++row) {
-        output[col * rows + row] = input[row * cols + col];
-      }
-    }
-  }
-
   template <typename S = scalar_t, typename E = scalar_t>
   eIcicleError coset_mul(
     int logn,
     int domain_max_size,
     E* elements,
     int batch_size,
+    bool columns_batch,
     const S* twiddles = nullptr,
     int stride = 0,
     const std::unique_ptr<S[]>& arbitrary_coset = nullptr,
     bool bit_rev = false,
-    NTTDir dir = NTTDir::kForward,
-    bool columns_batch = false)
+    NTTDir dir = NTTDir::kForward)
   {
     uint64_t size = 1 << logn;
+    uint64_t i_mem_idx;
     int idx;
+    int batch_stride = columns_batch ? batch_size : 1;
     for (int batch = 0; batch < batch_size; ++batch) {
-      E* current_elements = elements + batch * size;
+      E* current_elements = columns_batch ? elements + batch : elements + batch * size;
       if (arbitrary_coset) {
         for (int i = 1; i < size; ++i) {
           idx = columns_batch ? batch : i;
@@ -224,7 +309,7 @@ namespace ntt_cpu {
         for (int i = 1; i < size; ++i) {
           idx = bit_rev ? stride * (bit_reverse(i, logn)) : stride * i;
           idx = dir == NTTDir::kForward ? idx : domain_max_size - idx;
-          current_elements[i] = current_elements[i] * twiddles[idx];
+          current_elements[batch_stride * i] = current_elements[batch_stride * i] * twiddles[idx];
         }
       }
     }
@@ -232,49 +317,318 @@ namespace ntt_cpu {
   }
 
   template <typename S = scalar_t, typename E = scalar_t>
-  eIcicleError
-  cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
+  void refactor_and_reorder(
+    E* layer_output,
+    E* next_layer_input,
+    const S* twiddles,
+    int batch_size,
+    bool columns_batch,
+    int domain_max_size,
+    std::vector<int> layers_sntt_log_size = {},
+    int layer = 0,
+    icicle::NTTDir dir = icicle::NTTDir::kForward)
   {
-    if (size & (size - 1)) {
-      ICICLE_LOG_ERROR << "Size must be a power of 2. Size = " << size;
-      return eIcicleError::INVALID_ARGUMENT;
+    int sntt_size = 1 << layers_sntt_log_size[1];
+    int nof_sntts = 1 << layers_sntt_log_size[0];
+    int ntt_size = 1 << (layers_sntt_log_size[0] + layers_sntt_log_size[1]);
+    auto temp_elements =
+      std::make_unique<E[]>(ntt_size * batch_size); // TODO shanie - consider using an algorithm for sorting in-place
+    int stride = columns_batch ? batch_size : 1;
+    for (int batch = 0; batch < batch_size; ++batch) {
+      E* cur_layer_output = columns_batch ? layer_output + batch : layer_output + batch * ntt_size;
+      E* cur_temp_elements = columns_batch ? temp_elements.get() + batch : temp_elements.get() + batch * ntt_size;
+      for (int sntt_idx = 0; sntt_idx < nof_sntts; sntt_idx++) {
+        for (int elem = 0; elem < sntt_size; elem++) {
+          uint64_t tw_idx = (dir == NTTDir::kForward)
+                              ? ((domain_max_size / ntt_size) * sntt_idx * elem)
+                              : domain_max_size - ((domain_max_size / ntt_size) * sntt_idx * elem);
+          cur_temp_elements[stride * (sntt_idx * sntt_size + elem)] =
+            cur_layer_output[stride * (elem * nof_sntts + sntt_idx)] * twiddles[tw_idx];
+        }
+      }
     }
+    std::copy(temp_elements.get(), temp_elements.get() + ntt_size * batch_size, next_layer_input);
+  }
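+  // For example, with layers_sntt_log_size = {2, 2} (a 16-point NTT viewed as a 4x4 matrix): the element read from
+  // index elem * 4 + sntt_idx is scaled by w^(sntt_idx * elem), where w = twiddles[domain_max_size / 16] is a 16th
+  // root of unity, and is written to index sntt_idx * 4 + elem - the twiddle-and-transpose step of a four-step NTT.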
 
-    // Copy input to "temp_elements" instead of pointing temp_elements to input to ensure freeing temp_elements does not
-    // free the input, preventing a potential double-free error.
-    // TODO [SHANIE]: Later, remove temp_elements and perform all calculations in-place
-    // (implement NTT for the case where columns_batch=true, in-place).
+  template <typename S = scalar_t, typename E = scalar_t>
+  void refactor_output(
+    E* layer_output,
+    E* next_layer_input,
+    uint64_t tot_ntt_size,
+    int batch_size,
+    bool columns_batch,
+    const S* twiddles,
+    int domain_max_size,
+    std::vector<int> layers_sntt_log_size = {},
+    int layer = 0,
+    icicle::NTTDir dir = icicle::NTTDir::kForward)
+  {
+    int subntt_size = 1 << layers_sntt_log_size[0];
+    int nof_subntts = 1 << layers_sntt_log_size[1];
+    int nof_blocks = 1 << layers_sntt_log_size[2];
+    int i, j;
+    int ntt_size = layer == 0 ? 1 << (layers_sntt_log_size[0] + layers_sntt_log_size[1])
+                              : 1 << (layers_sntt_log_size[0] + layers_sntt_log_size[1] + layers_sntt_log_size[2]);
+    int stride = columns_batch ? batch_size : 1;
+    for (int batch = 0; batch < batch_size; ++batch) {
+      E* current_layer_output = columns_batch ? layer_output + batch : layer_output + batch * tot_ntt_size;
+      E* current_next_layer_input = columns_batch ? next_layer_input + batch : next_layer_input + batch * tot_ntt_size;
+      for (int block_idx = 0; block_idx < nof_blocks; block_idx++) {
+        for (int sntt_idx = 0; sntt_idx < nof_subntts; sntt_idx++) {
+          for (int elem = 0; elem < subntt_size; elem++) {
+            uint64_t elem_mem_idx = stride * idx_in_mem(elem, block_idx, sntt_idx, layers_sntt_log_size, 0);
+            i = (layer == 0) ? elem : elem + sntt_idx * subntt_size;
+            j = (layer == 0) ? sntt_idx : block_idx;
+            uint64_t tw_idx = (dir == NTTDir::kForward) ? ((domain_max_size / ntt_size) * j * i)
+                                                        : domain_max_size - ((domain_max_size / ntt_size) * j * i);
+            current_next_layer_input[elem_mem_idx] = current_layer_output[elem_mem_idx] * twiddles[tw_idx];
+          }
+        }
+      }
+    }
+  }
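+  // E.g. for layer 0 with layers_sntt_log_size = {4, 4, 3}: ntt_size = 2^8, and element elem of sub-NTT sntt_idx in
+  // every block is scaled in place by w^(elem * sntt_idx), where w = twiddles[domain_max_size / 2^8] is a 2^8-th root
+  // of unity. Unlike refactor_and_reorder, this inter-layer twiddle correction moves nothing in memory.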
 
-    const uint64_t total_size = size * config.batch_size;
-    auto temp_elements = std::make_unique<E[]>(total_size);
-    auto vec_ops_config = default_vec_ops_config();
-    if (config.columns_batch) {
-      transpose(input, temp_elements.get(), size, config.batch_size);
-    } else {
-      std::copy(input, input + total_size, temp_elements.get());
+  template <typename E = scalar_t>
+  eIcicleError reorder_input(
+    E* input, uint64_t size, int batch_size, bool columns_batch, const std::vector<int> layers_sntt_log_size = {})
+  { // TODO shanie future - consider using an algorithm for efficient reordering
+    if (layers_sntt_log_size.empty()) {
+      ICICLE_LOG_ERROR << "layers_sntt_log_size is null";
+      return eIcicleError::INVALID_ARGUMENT;
     }
-    const int logn = int(log2(size));
+    int stride = columns_batch ? batch_size : 1;
+    auto temp_input = std::make_unique<E[]>(batch_size * size);
+    for (int batch = 0; batch < batch_size; ++batch) {
+      E* current_elements = columns_batch ? input + batch : input + batch * size;
+      E* current_temp_input = columns_batch ? temp_input.get() + batch : temp_input.get() + batch * size;
+      uint64_t idx = 0;
+      uint64_t new_idx = 0;
+      int cur_ntt_log_size = layers_sntt_log_size[0];
+      int next_ntt_log_size = layers_sntt_log_size[1];
+      for (int i = 0; i < size; i++) {
+        int subntt_idx = i >> cur_ntt_log_size;
+        int element = i & ((1 << cur_ntt_log_size) - 1);
+        new_idx = subntt_idx + (element << next_ntt_log_size);
+        current_temp_input[stride * i] = current_elements[stride * new_idx];
+      }
+    }
+    std::copy(temp_input.get(), temp_input.get() + batch_size * size, input);
+    return eIcicleError::SUCCESS;
+  }
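+  // Example with layers_sntt_log_size = {2, 2} (size 16): i = 6 gives subntt_idx = 1, element = 2, so temp[6] is
+  // gathered from input[1 + (2 << 2)] = input[9]. After this pass, sub-NTT k occupies the contiguous range
+  // [4k, 4k + 4) and holds every 4th element of the original input starting at k (the four-step column gather).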
+
+  template <typename E = scalar_t>
+  eIcicleError reorder_output(
+    E* output,
+    uint64_t size,
+    const std::vector<int> layers_sntt_log_size = {},
+    int batch_size = 1,
+    bool columns_batch = 0)
+  { // TODO shanie future - consider using an algorithm for efficient reordering
+    if (layers_sntt_log_size.empty()) {
+      ICICLE_LOG_ERROR << "layers_sntt_log_size is null";
+      return eIcicleError::INVALID_ARGUMENT;
+    }
+    int temp_output_size = columns_batch ? size * batch_size : size;
+    auto temp_output = std::make_unique<E[]>(temp_output_size);
+    uint64_t idx = 0;
+    uint64_t mem_idx = 0;
+    uint64_t new_idx = 0;
+    int subntt_idx;
+    int element;
+    int s0 = layers_sntt_log_size[0];
+    int s1 = layers_sntt_log_size[1];
+    int s2 = layers_sntt_log_size[2];
+    int p0, p1, p2;
+    int stride = columns_batch ? batch_size : 1;
+    int rep = columns_batch ? batch_size : 1;
+    for (int batch = 0; batch < rep; ++batch) {
+      E* current_elements =
+        columns_batch
+          ? output + batch
+          : output; // if columns_batch=false, then output is already shifted by batch*size when calling the function
+      E* current_temp_output = columns_batch ? temp_output.get() + batch : temp_output.get();
+      for (int i = 0; i < size; i++) {
+        if (layers_sntt_log_size[2]) {
+          p0 = (i >> (s1 + s2));
+          p1 = (((i >> s2) & ((1 << (s1)) - 1)) << s0);
+          p2 = ((i & ((1 << s2) - 1)) << (s0 + s1));
+          new_idx = p0 + p1 + p2;
+          current_temp_output[stride * new_idx] = current_elements[stride * i];
+        } else {
+          subntt_idx = i >> s1;
+          element = i & ((1 << s1) - 1);
+          new_idx = subntt_idx + (element << s0);
+          current_temp_output[stride * new_idx] = current_elements[stride * i];
+        }
+      }
+    }
+    std::copy(temp_output.get(), temp_output.get() + temp_output_size, output);
+    return eIcicleError::SUCCESS;
+  }
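+  // This undoes the layer-digit ordering: writing i with digits (a, b, c) of widths (s0, s1, s2), i.e.
+  // i = a * 2^(s1+s2) + b * 2^s2 + c, each element moves to new_idx = a + b * 2^s0 + c * 2^(s0+s1) - a digit
+  // reversal across the sub-NTT layers (the two-layer branch is the same computation with the c digit absent).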
+
+  template <typename S = scalar_t, typename E = scalar_t>
+  eIcicleError cpu_ntt_basic(
+    const icicle::Device& device,
+    E* input,
+    uint64_t original_size,
+    icicle::NTTDir dir,
+    const icicle::NTTConfig<S>& config,
+    E* output,
+    int block_idx = 0,
+    int subntt_idx = 0,
+    const std::vector<int> layers_sntt_log_size = {},
+    int layer = 0)
+  {
+    const uint64_t subntt_size = (1 << layers_sntt_log_size[layer]);
+    const uint64_t total_memory_size = original_size * config.batch_size;
+    const int log_original_size = int(log2(original_size));
     const S* twiddles = CpuNttDomain<S>::s_ntt_domain.get_twiddles();
     const int domain_max_size = CpuNttDomain<S>::s_ntt_domain.get_max_size();
 
-    std::unique_ptr<S[]> arbitrary_coset = nullptr;
-    int coset_stride = 0;
-    if (domain_max_size < size) {
+    if (domain_max_size < subntt_size) {
       ICICLE_LOG_ERROR << "NTT domain size is less than input size. Domain size = " << domain_max_size
-                       << ", Input size = " << size;
+                       << ", Input size = " << subntt_size;
       return eIcicleError::INVALID_ARGUMENT;
     }
 
-    if (config.coset_gen != S::one()) { // TODO SHANIE - implement more efficient way to find coset_stride
-      for (int i = 1; i <= domain_max_size; i++) {
-        if (twiddles[i] == config.coset_gen) {
-          coset_stride = i;
-          break;
+    bool dit = true;
+    bool input_rev = false;
+    bool output_rev = false;
+    // bool need_to_reorder = false;
+    bool need_to_reorder = true;
+    // switch (config.ordering) { // kNN, kNR, kRN, kRR, kNM, kMN
+    // case Ordering::kNN: // dit R --> N
+    //   need_to_reorder = true;
+    //   break;
+    // case Ordering::kNR: // dif N --> R
+    // case Ordering::kNM: // dif N --> R
+    //   dit = false;
+    //   output_rev = true;
+    //   break;
+    // case Ordering::kRR: // dif N --> R
+    //   input_rev = true;
+    //   output_rev = true;
+    //   need_to_reorder = true;
+    //   dit = false; // dif
+    //   break;
+    // case Ordering::kRN: // dit R --> N
+    // case Ordering::kMN: // dit R --> N
+    //   input_rev = true;
+    //   break;
+    // default:
+    //   return eIcicleError::INVALID_ARGUMENT;
+    // }
+
+    if (need_to_reorder) {
+      reorder_by_bit_reverse(
+        log_original_size, input, config.batch_size, config.columns_batch, block_idx, subntt_idx,
+        layers_sntt_log_size, layer);
+    } // TODO - check if accessing the fixed indices instead of reordering may be more efficient
+
+    // NTT/INTT
+    if (dit) {
+      dit_ntt(
+        input, original_size, config.batch_size, config.columns_batch, twiddles, dir, domain_max_size, block_idx,
+        subntt_idx, layers_sntt_log_size, layer); // R --> N
+    } else {
+      dif_ntt(
+        input, original_size, config.batch_size, config.columns_batch, twiddles, dir, domain_max_size, block_idx,
+        subntt_idx, layers_sntt_log_size, layer); // N --> R
+    }
+
+    return eIcicleError::SUCCESS;
+  }
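+  // Note: with the ordering switch above commented out, this currently hard-codes the kNN path (bit-reverse, then
+  // DIT R --> N) for every sub-NTT. A single call transforms only the sub-NTT selected by (block_idx, subntt_idx,
+  // layer), addressing memory through idx_in_mem and working in place on the buffer passed as `input`.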
+
+  template <typename S = scalar_t, typename E = scalar_t>
+  eIcicleError cpu_ntt_parallel(
+    const Device& device,
+    uint64_t size,
+    uint64_t original_size,
+    NTTDir dir,
+    const NTTConfig<S>& config,
+    E* output,
+    const S* twiddles,
+    const int domain_max_size = 0)
+  {
+    const int logn = int(log2(size));
+    std::vector<int> layers_sntt_log_size(
+      std::begin(layers_subntt_log_size[logn]), std::end(layers_subntt_log_size[logn]));
+
+    // Assuming that the NTT fits in the cache, we split the NTT into layers and calculate them one after the other.
+    // Sub-NTTs inside the same layer are calculated in parallel.
+    // Sorting is not needed, since the elements needed for each sub-NTT are close to each other in memory.
+    // Instead of sorting, we use the function idx_in_mem to calculate the memory index of each element.
+    for (int layer = 0; layer < layers_sntt_log_size.size(); layer++) {
+      if (layer == 0) {
+        int log_nof_subntts = layers_sntt_log_size[1];
+        int log_nof_blocks = layers_sntt_log_size[2];
+        for (int block_idx = 0; block_idx < (1 << log_nof_blocks); block_idx++) {
+          for (int subntt_idx = 0; subntt_idx < (1 << log_nof_subntts); subntt_idx++) {
+            cpu_ntt_basic(
+              device, output, original_size, dir, config, output, block_idx, subntt_idx, layers_sntt_log_size, layer);
+          }
+        }
+      }
+      if (layer == 1 && layers_sntt_log_size[1]) {
+        int log_nof_subntts = layers_sntt_log_size[0];
+        int log_nof_blocks = layers_sntt_log_size[2];
+        for (int block_idx = 0; block_idx < (1 << log_nof_blocks); block_idx++) {
+          for (int subntt_idx = 0; subntt_idx < (1 << log_nof_subntts); subntt_idx++) {
+            cpu_ntt_basic(
+              device, output /*input*/, original_size, dir, config, output, block_idx, subntt_idx,
+              layers_sntt_log_size, layer); // input=output (in-place)
+          }
+        }
+      }
+      if (layer == 2 && layers_sntt_log_size[2]) {
+        int log_nof_blocks = layers_sntt_log_size[0] + layers_sntt_log_size[1];
+        for (int block_idx = 0; block_idx < (1 << log_nof_blocks); block_idx++) {
+          cpu_ntt_basic(
+            device, output /*input*/, original_size, dir, config, output, block_idx, 0 /*subntt_idx - not used*/,
+            layers_sntt_log_size, layer); // input=output (in-place)
+        }
+      }
+      if (layer != 2 && layers_sntt_log_size[layer + 1] != 0) {
+        refactor_output(
+          output, output /*input for next layer*/, original_size, config.batch_size, config.columns_batch, twiddles,
+          domain_max_size, layers_sntt_log_size, layer, dir);
+      }
+    }
+    // Sort the output at the end so that the elements are in the right order.
+    // TODO SHANIE - After implementing for different ordering, maybe this should be done in a different place
+    // - When implementing real parallelism, consider sorting in parallel and in-place
+    if (layers_sntt_log_size[1]) { // at least 2 layers
+      if (config.columns_batch) {
+        reorder_output(output, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
+      } else {
+        for (int b = 0; b < config.batch_size; b++) {
+          reorder_output(
+            output + b * original_size, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
+        }
+      }
-    if (coset_stride == 0) { // if the coset_gen is not found in the twiddles, calculate arbitrary coset
+    }
+    return eIcicleError::SUCCESS;
+  }
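+  // Schedule example for logn = 11 ({4, 4, 3}): layer 0 runs 2^3 * 2^4 = 128 independent 2^4-point sub-NTTs, layer 1
+  // another 128, and layer 2 runs 2^8 = 256 sub-NTTs of 2^3 points each; refactor_output applies the inter-layer
+  // twiddles in between, and reorder_output restores the natural order at the end.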
+
+  template <typename S = scalar_t, typename E = scalar_t>
+  eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
+  {
+    if (size & (size - 1)) {
+      ICICLE_LOG_ERROR << "Size must be a power of 2. size = " << size;
+      return eIcicleError::INVALID_ARGUMENT;
+    }
+    const int logn = int(log2(size));
+    const S* twiddles = CpuNttDomain<S>::s_ntt_domain.get_twiddles();
+    const int domain_max_size = CpuNttDomain<S>::s_ntt_domain.get_max_size();
+
+    // TODO SHANIE - move to init domain
+    int coset_stride = 0;
+    std::unique_ptr<S[]> arbitrary_coset = nullptr;
+    if (config.coset_gen != S::one()) { // TODO SHANIE - implement more efficient way to find coset_stride
+      try {
+        coset_stride = CpuNttDomain<S>::s_ntt_domain.coset_index.at(config.coset_gen);
+        ICICLE_LOG_DEBUG << "Coset generator found in twiddles. coset_stride=" << coset_stride;
+      } catch (const std::out_of_range& oor) {
         ICICLE_LOG_DEBUG << "Coset generator not found in twiddles. Calculating arbitrary coset.";
-        auto temp_cosets = std::make_unique<S[]>(domain_max_size + 1);
         arbitrary_coset = std::make_unique<S[]>(domain_max_size + 1);
         arbitrary_coset[0] = S::one();
         S coset_gen = dir == NTTDir::kForward ? config.coset_gen : S::inverse(config.coset_gen); // inverse for INTT
@@ -284,74 +638,81 @@ namespace ntt_cpu {
       }
     }
 
-    bool dit = true;
-    bool input_rev = false;
-    bool output_rev = false;
-    bool need_to_reorder = false;
-    bool coset = (config.coset_gen != S::one() && dir == NTTDir::kForward);
-    switch (config.ordering) { // kNN, kNR, kRN, kRR, kNM, kMN
-    case Ordering::kNN:
-      need_to_reorder = true;
-      break;
-    case Ordering::kNR:
-    case Ordering::kNM:
-      dit = false; // dif
-      output_rev = true;
-      break;
-    case Ordering::kRR:
-      input_rev = true;
-      output_rev = true;
-      need_to_reorder = true;
-      dit = false; // dif
-      break;
-    case Ordering::kRN:
-    case Ordering::kMN:
-      input_rev = true;
-      break;
-    default:
-      return eIcicleError::INVALID_ARGUMENT;
+    std::copy(input, input + size * config.batch_size, output);
+    if (config.ordering == Ordering::kRN || config.ordering == Ordering::kRR) {
+      reorder_by_bit_reverse(
+        logn, output, config.batch_size,
+        config.columns_batch); // TODO - check if accessing the fixed indices instead of reordering may be more efficient
     }
 
-    if (coset) {
+    if (config.coset_gen != S::one() && dir == NTTDir::kForward) {
+      // bool input_rev = config.ordering == Ordering::kRR || config.ordering == Ordering::kMN || config.ordering ==
+      // Ordering::kRN;
+      bool input_rev = false;
       coset_mul(
-        logn, domain_max_size, temp_elements.get(), config.batch_size, twiddles, coset_stride, arbitrary_coset,
-        input_rev);
+        logn, domain_max_size, output, config.batch_size, config.columns_batch, twiddles, coset_stride,
+        arbitrary_coset, input_rev, dir);
     }
 
+    std::vector<int> layers_sntt_log_size(
+      std::begin(layers_subntt_log_size[logn]), std::end(layers_subntt_log_size[logn]));
+
+    if (logn > 15) {
+      // TODO future - maybe the 4th layer can start in parallel to the 3rd layer?
+      // Assuming that the NTT doesn't fit in the cache, we split it into 2 layers and calculate them one after the
+      // other. Inside each layer, each sub-NTT calculation is split into layers as well, and those are calculated in
+      // parallel. Sorting is done between the layers, so that the elements needed for each sub-NTT are close to each
+      // other in memory.
+      int stride = config.columns_batch ? config.batch_size : 1;
+      reorder_input(output, size, config.batch_size, config.columns_batch, layers_sntt_log_size);
+      for (int subntt_idx = 0; subntt_idx < (1 << layers_sntt_log_size[1]); subntt_idx++) {
+        E* current_elements =
+          output + stride * (subntt_idx << layers_sntt_log_size[0]); // output + subntt_idx * subntt_size
+        cpu_ntt_parallel(
+          device, (1 << layers_sntt_log_size[0]), size, dir, config, current_elements, twiddles, domain_max_size);
+      }
+      refactor_and_reorder(
+        output, output /*input for next layer*/, twiddles, config.batch_size, config.columns_batch, domain_max_size,
+        layers_sntt_log_size, 0 /*layer*/, dir);
+      for (int subntt_idx = 0; subntt_idx < (1 << layers_sntt_log_size[0]); subntt_idx++) {
+        E* current_elements =
+          output + stride * (subntt_idx << layers_sntt_log_size[1]); // output + subntt_idx * subntt_size
+        cpu_ntt_parallel(
+          device, (1 << layers_sntt_log_size[1]), size, dir, config, current_elements, twiddles, domain_max_size);
+      }
+      if (config.columns_batch) {
+        reorder_output(output, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
+      } else {
+        for (int b = 0; b < config.batch_size; b++) {
+          reorder_output(output + b * size, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
+        }
+      }
-    if (need_to_reorder) { reorder_by_bit_reverse(logn, temp_elements.get(), config.batch_size); }
-
-    // NTT/INTT
-    if (dit) {
-      dit_ntt(temp_elements.get(), size, config.batch_size, twiddles, dir, domain_max_size);
     } else {
-      dif_ntt(temp_elements.get(), size, config.batch_size, twiddles, dir, domain_max_size);
+      cpu_ntt_parallel(device, size, size, dir, config, output, twiddles, domain_max_size);
    }
 
-    if (dir == NTTDir::kInverse) {
-      // Normalize results
+    if (dir == NTTDir::kInverse) { // TODO SHANIE - do that in parallel
       S inv_size = S::inv_log_size(logn);
-      for (int i = 0; i < total_size; ++i) {
-        temp_elements[i] = temp_elements[i] * inv_size;
+      for (uint64_t i = 0; i < size * config.batch_size; ++i) {
+        output[i] = output[i] * inv_size;
       }
       if (config.coset_gen != S::one()) {
+        // bool output_rev = config.ordering == Ordering::kNR || config.ordering == Ordering::kNM || config.ordering ==
+        // Ordering::kRR;
+        bool output_rev = false;
         coset_mul(
-          logn, domain_max_size, temp_elements.get(), config.batch_size, twiddles, coset_stride, arbitrary_coset,
-          output_rev, dir);
+          logn, domain_max_size, output, config.batch_size, config.columns_batch, twiddles, coset_stride,
+          arbitrary_coset, output_rev, dir);
       }
     }
 
-    if (config.columns_batch) {
-      transpose(temp_elements.get(), output, config.batch_size, size);
-    } else {
-      std::copy(temp_elements.get(), temp_elements.get() + total_size, output);
+    if (config.ordering == Ordering::kNR || config.ordering == Ordering::kRR) {
+      reorder_by_bit_reverse(
+        logn, output, config.batch_size,
+        config.columns_batch); // TODO - check if accessing the fixed indices instead of reordering may be more efficient
    }
 
     return eIcicleError::SUCCESS;
   }
-
-  template <typename S = scalar_t, typename E = scalar_t>
-  eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
-  {
-    return cpu_ntt_ref(device, input, size, dir, config, output);
-  }
 } // namespace ntt_cpu
\ No newline at end of file