diff --git a/icicle/backend/cpu/include/cpu_ntt_domain.h b/icicle/backend/cpu/include/cpu_ntt_domain.h index 01ce6794b..3cfe88018 100644 --- a/icicle/backend/cpu/include/cpu_ntt_domain.h +++ b/icicle/backend/cpu/include/cpu_ntt_domain.h @@ -26,7 +26,6 @@ namespace ntt_cpu { int max_size = 0; int max_log_size = 0; std::unique_ptr twiddles; - std::unique_ptr inv_log_sizes; std::unique_ptr winograd8_twiddles; std::unique_ptr winograd8_twiddles_inv; std::unique_ptr winograd16_twiddles; @@ -43,7 +42,6 @@ namespace ntt_cpu { static eIcicleError get_root_of_unity_from_domain(const Device& device, uint64_t logn, S* rou /*OUT*/); const inline S* get_twiddles() const { return twiddles.get(); } - const inline S* get_inv_log_sizes() const { return inv_log_sizes.get(); } const inline S* get_winograd8_twiddles() const { return winograd8_twiddles.get(); } const inline S* get_winograd8_twiddles_inv() const { return winograd8_twiddles_inv.get(); } const inline S* get_winograd16_twiddles() const { return winograd16_twiddles.get(); } @@ -103,11 +101,6 @@ namespace ntt_cpu { return eIcicleError::INVALID_ARGUMENT; } - s_ntt_domain.inv_log_sizes = std::make_unique(s_ntt_domain.max_log_size); - for (int i = 0; i < s_ntt_domain.max_log_size; i++) { - s_ntt_domain.inv_log_sizes[i] = S::inv_log_size(i); - } - // calculate twiddles // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements @@ -124,9 +117,6 @@ namespace ntt_cpu { } s_ntt_domain.twiddles = std::move(temp_twiddles); // Assign twiddles using unique_ptr - const S inv2 = s_ntt_domain.inv_log_sizes[1]; - // const S inv2 = S::inv_log_size(1); - // Winograd 8 if (s_ntt_domain.max_log_size >= 3) { auto temp_win8_twiddles = std::make_unique(3); @@ -134,12 +124,13 @@ namespace ntt_cpu { int basic_tw_idx = (s_ntt_domain.max_size >> 3); S basic_tw = s_ntt_domain.twiddles[basic_tw_idx]; temp_win8_twiddles[0] = basic_tw * basic_tw; - temp_win8_twiddles[1] = (basic_tw + temp_win8_twiddles[0] * basic_tw) * inv2; - temp_win8_twiddles[2] = (basic_tw - temp_win8_twiddles[0] * basic_tw) * inv2; // = temp_win8_twiddles_inv[2] - basic_tw = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx]; // for inverse ntt - temp_win8_twiddles_inv[0] = basic_tw * basic_tw; // temp_win8_twiddles_inv[0] + temp_win8_twiddles[1] = (basic_tw + temp_win8_twiddles[0] * basic_tw) * S::inv_log_size(1); + temp_win8_twiddles[2] = + (basic_tw - temp_win8_twiddles[0] * basic_tw) * S::inv_log_size(1); // = temp_win8_twiddles_inv[2] + basic_tw = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx]; // for inverse ntt + temp_win8_twiddles_inv[0] = basic_tw * basic_tw; // temp_win8_twiddles_inv[0] temp_win8_twiddles_inv[1] = - (basic_tw + temp_win8_twiddles_inv[0] * basic_tw) * inv2; // temp_win8_twiddles_inv[1] + (basic_tw + temp_win8_twiddles_inv[0] * basic_tw) * S::inv_log_size(1); // temp_win8_twiddles_inv[1] temp_win8_twiddles_inv[2] = temp_win8_twiddles[2]; s_ntt_domain.winograd8_twiddles = std::move(temp_win8_twiddles); // Assign twiddles using unique_ptr @@ -160,35 +151,35 @@ namespace ntt_cpu { temp_win16_twiddles[4] = s_ntt_domain.twiddles[basic_tw_idx * 0]; temp_win16_twiddles[5] = s_ntt_domain.twiddles[basic_tw_idx * 4]; temp_win16_twiddles[6] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win16_twiddles[7] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win16_twiddles[8] = s_ntt_domain.twiddles[basic_tw_idx * 0]; temp_win16_twiddles[9] = s_ntt_domain.twiddles[basic_tw_idx * 4]; temp_win16_twiddles[10] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win16_twiddles[11] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win16_twiddles[12] = (s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles[13] = S::neg( s_ntt_domain.twiddles[basic_tw_idx * 1] + s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles[14] = - (s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5]) * S::inv_log_size(1); temp_win16_twiddles[15] = (s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles[16] = (s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles[17] = - (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 5]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 5]) * S::inv_log_size(1); temp_win16_twiddles_inv[0] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 0]; temp_win16_twiddles_inv[1] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 0]; @@ -199,48 +190,48 @@ namespace ntt_cpu { temp_win16_twiddles_inv[5] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 4]; temp_win16_twiddles_inv[6] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[7] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[8] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 0]; temp_win16_twiddles_inv[9] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 4]; temp_win16_twiddles_inv[10] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[11] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[12] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[13] = S::neg( s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[14] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[15] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[16] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win16_twiddles_inv[17] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5]) * - inv2; + S::inv_log_size(1); s_ntt_domain.winograd16_twiddles = std::move(temp_win16_twiddles); // Assign twiddles using unique_ptr s_ntt_domain.winograd16_twiddles_inv = std::move(temp_win16_twiddles_inv); // Assign twiddles using unique_ptr @@ -259,35 +250,35 @@ namespace ntt_cpu { temp_win32_twiddles[4] = s_ntt_domain.twiddles[basic_tw_idx * 0]; temp_win32_twiddles[5] = s_ntt_domain.twiddles[basic_tw_idx * 4]; temp_win32_twiddles[6] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win32_twiddles[7] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win32_twiddles[8] = s_ntt_domain.twiddles[basic_tw_idx * 0]; temp_win32_twiddles[9] = s_ntt_domain.twiddles[basic_tw_idx * 4]; temp_win32_twiddles[10] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win32_twiddles[11] = - (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6]) * S::inv_log_size(1); temp_win32_twiddles[12] = (s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[13] = S::neg( s_ntt_domain.twiddles[basic_tw_idx * 1] + s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[14] = - (s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5]) * S::inv_log_size(1); temp_win32_twiddles[15] = (s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[16] = (s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[17] = - (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 5]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 5]) * S::inv_log_size(1); basic_tw_idx = s_ntt_domain.max_size >> 5; @@ -295,105 +286,105 @@ namespace ntt_cpu { temp_win32_twiddles[19] = s_ntt_domain.twiddles[basic_tw_idx * 8]; temp_win32_twiddles[20] = - (s_ntt_domain.twiddles[basic_tw_idx * 4] + s_ntt_domain.twiddles[basic_tw_idx * 12]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 4] + s_ntt_domain.twiddles[basic_tw_idx * 12]) * S::inv_log_size(1); temp_win32_twiddles[21] = - (s_ntt_domain.twiddles[basic_tw_idx * 4] - s_ntt_domain.twiddles[basic_tw_idx * 12]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 4] - s_ntt_domain.twiddles[basic_tw_idx * 12]) * S::inv_log_size(1); temp_win32_twiddles[22] = (s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 14] - s_ntt_domain.twiddles[basic_tw_idx * 6] - s_ntt_domain.twiddles[basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[23] = S::neg( s_ntt_domain.twiddles[basic_tw_idx * 2] + s_ntt_domain.twiddles[basic_tw_idx * 14] + s_ntt_domain.twiddles[basic_tw_idx * 6] + s_ntt_domain.twiddles[basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[24] = - (s_ntt_domain.twiddles[basic_tw_idx * 6] + s_ntt_domain.twiddles[basic_tw_idx * 10]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 6] + s_ntt_domain.twiddles[basic_tw_idx * 10]) * S::inv_log_size(1); temp_win32_twiddles[25] = (s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 14] - s_ntt_domain.twiddles[basic_tw_idx * 6] + s_ntt_domain.twiddles[basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[26] = (s_ntt_domain.twiddles[basic_tw_idx * 14] - s_ntt_domain.twiddles[basic_tw_idx * 2] - s_ntt_domain.twiddles[basic_tw_idx * 6] + s_ntt_domain.twiddles[basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[27] = - (s_ntt_domain.twiddles[basic_tw_idx * 6] - s_ntt_domain.twiddles[basic_tw_idx * 10]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 6] - s_ntt_domain.twiddles[basic_tw_idx * 10]) * S::inv_log_size(1); temp_win32_twiddles[28] = (s_ntt_domain.twiddles[basic_tw_idx * 1] + s_ntt_domain.twiddles[basic_tw_idx * 15] - s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 7] - s_ntt_domain.twiddles[basic_tw_idx * 9] + s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[29] = (s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 15] + s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 7] - s_ntt_domain.twiddles[basic_tw_idx * 9] + s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[30] = (s_ntt_domain.twiddles[basic_tw_idx * 7] + s_ntt_domain.twiddles[basic_tw_idx * 9] - s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[31] = (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 7] - s_ntt_domain.twiddles[basic_tw_idx * 9] - s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 15] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[32] = (s_ntt_domain.twiddles[basic_tw_idx * 7] + s_ntt_domain.twiddles[basic_tw_idx * 9] + s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 15] + s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[33] = (s_ntt_domain.twiddles[basic_tw_idx * 1] + s_ntt_domain.twiddles[basic_tw_idx * 15] - s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[34] = (s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[35] = (S::neg( s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 11] + s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13])) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[36] = - (s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * S::inv_log_size(1); temp_win32_twiddles[37] = (s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 15] - s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 7] + s_ntt_domain.twiddles[basic_tw_idx * 9] + s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[38] = (s_ntt_domain.twiddles[basic_tw_idx * 15] - s_ntt_domain.twiddles[basic_tw_idx * 1] + s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 7] + s_ntt_domain.twiddles[basic_tw_idx * 9] + s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[39] = (s_ntt_domain.twiddles[basic_tw_idx * 7] - s_ntt_domain.twiddles[basic_tw_idx * 9] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[40] = (s_ntt_domain.twiddles[basic_tw_idx * 7] - s_ntt_domain.twiddles[basic_tw_idx * 9] - s_ntt_domain.twiddles[basic_tw_idx * 5] + s_ntt_domain.twiddles[basic_tw_idx * 11] + s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 15] + s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[41] = (s_ntt_domain.twiddles[basic_tw_idx * 9] - s_ntt_domain.twiddles[basic_tw_idx * 7] + s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 11] + s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 15] + s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[42] = (s_ntt_domain.twiddles[basic_tw_idx * 15] - s_ntt_domain.twiddles[basic_tw_idx * 1] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[43] = (s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[44] = (s_ntt_domain.twiddles[basic_tw_idx * 11] - s_ntt_domain.twiddles[basic_tw_idx * 5] - s_ntt_domain.twiddles[basic_tw_idx * 3] + s_ntt_domain.twiddles[basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles[45] = - (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * inv2; + (s_ntt_domain.twiddles[basic_tw_idx * 3] - s_ntt_domain.twiddles[basic_tw_idx * 13]) * S::inv_log_size(1); s_ntt_domain.winograd32_twiddles = std::move(temp_win32_twiddles); // Assign twiddles using unique_ptr @@ -407,48 +398,48 @@ namespace ntt_cpu { temp_win32_twiddles_inv[5] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 4]; temp_win32_twiddles_inv[6] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[7] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[8] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 0]; temp_win32_twiddles_inv[9] = s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 4]; temp_win32_twiddles_inv[10] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[11] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[12] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[13] = S::neg( s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[14] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[15] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[16] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[17] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5]) * - inv2; + S::inv_log_size(1); basic_tw_idx = s_ntt_domain.max_size >> 5; @@ -457,41 +448,41 @@ namespace ntt_cpu { temp_win32_twiddles_inv[20] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 4] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 12]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[21] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 4] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 12]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[22] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 14] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[23] = S::neg( s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 14] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[24] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[25] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 14] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[26] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 14] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 2] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[27] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 6] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 10]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[28] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] - @@ -501,7 +492,7 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[29] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] + @@ -510,12 +501,12 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[30] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[31] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] - @@ -524,7 +515,7 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[32] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + @@ -533,26 +524,26 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[33] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[34] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 11] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[35] = (S::neg( s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 11] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13])) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[36] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[37] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + @@ -561,7 +552,7 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[38] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] + @@ -571,13 +562,13 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[39] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[40] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] + @@ -586,7 +577,7 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[41] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 9] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 7] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - @@ -595,25 +586,25 @@ namespace ntt_cpu { s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[42] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 15] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 1] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[43] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 11] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[44] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 11] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 5] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] + s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); temp_win32_twiddles_inv[45] = (s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 3] - s_ntt_domain.twiddles[s_ntt_domain.max_size - basic_tw_idx * 13]) * - inv2; + S::inv_log_size(1); s_ntt_domain.winograd32_twiddles_inv = std::move(temp_win32_twiddles_inv); // Assign twiddles using unique_ptr } diff --git a/icicle/backend/cpu/include/ntt_task.h b/icicle/backend/cpu/include/ntt_task.h index ba7530216..c0157ffa4 100644 --- a/icicle/backend/cpu/include/ntt_task.h +++ b/icicle/backend/cpu/include/ntt_task.h @@ -365,8 +365,7 @@ namespace ntt_cpu { current_elements[index_in_mem[4]] = current_elements[index_in_mem[4]] - T; if (last_layer && ntt_data->direction == NTTDir::kInverse) { - const S* inv_log_sizes = CpuNttDomain::s_ntt_domain.get_inv_log_sizes(); - S inv_size = inv_log_sizes[ntt_data->logn]; + S inv_size = S::inv_log_size(ntt_data->logn); for (uint64_t i = 0; i < 8; ++i) { current_elements[index_in_mem[i]] = current_elements[index_in_mem[i]] * inv_size; } @@ -557,8 +556,7 @@ namespace ntt_cpu { current_elements[index_in_mem[10]] = T; if (last_layer && ntt_data->direction == NTTDir::kInverse) { - const S* inv_log_sizes = CpuNttDomain::s_ntt_domain.get_inv_log_sizes(); - S inv_size = inv_log_sizes[ntt_data->logn]; + S inv_size = S::inv_log_size(ntt_data->logn); for (uint64_t i = 0; i < 16; ++i) { current_elements[index_in_mem[i]] = current_elements[index_in_mem[i]] * inv_size; } @@ -1238,8 +1236,7 @@ namespace ntt_cpu { current_elements[index_in_mem[31]] = temp_1[15] - temp_1[31]; if (last_layer && ntt_data->direction == NTTDir::kInverse) { - const S* inv_log_sizes = CpuNttDomain::s_ntt_domain.get_inv_log_sizes(); - S inv_size = inv_log_sizes[ntt_data->logn]; + S inv_size = S::inv_log_size(ntt_data->logn); for (uint64_t i = 0; i < 32; ++i) { current_elements[index_in_mem[i]] = current_elements[index_in_mem[i]] * inv_size; } diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 012d3ed30..b10a3e8e3 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -136,31 +136,6 @@ class Field "CUDA ERROR: field.h: error on inv_log_size(logn): logn(=%u) > omegas_count (=%u)", logn, CONFIG::omegas_count); assert(false); } -#endif // __CUDA_ARCH__ - Field rs = {1}; - for (int i = 0; i < logn; i++) { - if (rs.limbs_storage.limbs[0] & 1) - base_math::template add_sub_limbs(rs.limbs_storage, get_modulus(), rs.limbs_storage); - rs.limbs_storage = base_math::template right_shift(rs.limbs_storage); - } -#ifdef BARRET - return rs; -#else - return to_montgomery(rs); -#endif - } - - static HOST_DEVICE_INLINE Field inv_log_size_dep(uint32_t logn) - { - if (logn == 0) { return one(); } -#ifndef __CUDA_ARCH__ - if (logn > CONFIG::omegas_count) THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid inv index"); -#else - if (logn > CONFIG::omegas_count) { - printf( - "CUDA ERROR: field.h: error on inv_log_size(logn): logn(=%u) > omegas_count (=%u)", logn, CONFIG::omegas_count); - assert(false); - } #endif // __CUDA_ARCH__ storage_array const inv = CONFIG::inv; #ifdef BARRET