Skip to content

Commit

Permalink
actual correctness for ntt64
Browse files Browse the repository at this point in the history
  • Loading branch information
ChickenLover committed Aug 5, 2024
1 parent 733f5f3 commit aa06501
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 47 deletions.
74 changes: 49 additions & 25 deletions icicle/src/ntt/kernel_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -423,43 +423,67 @@ namespace mxntt {
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);

// if (threadIdx.x == 0) {
// printf(
// "T BEFORE: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
// threadIdx.x,
// engine.X[0].limbs_storage.limbs[0],
// engine.X[0].limbs_storage.limbs[1],
// engine.X[0].limbs_storage.limbs[2],
// engine.X[0].limbs_storage.limbs[3],
// engine.X[0].limbs_storage.limbs[4],
// engine.X[0].limbs_storage.limbs[5],
// engine.X[0].limbs_storage.limbs[6],
// engine.X[0].limbs_storage.limbs[7]
// );
// }
engine.loadBasicTwiddlesGeneric64(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv, false);
#pragma unroll 1
for (uint32_t phase = 0; phase < 2; phase++) {
printf(
"T BEFORE: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
threadIdx.x,
engine.X[0].limbs_storage.limbs[0],
engine.X[0].limbs_storage.limbs[1],
engine.X[0].limbs_storage.limbs[2],
engine.X[0].limbs_storage.limbs[3],
engine.X[0].limbs_storage.limbs[4],
engine.X[0].limbs_storage.limbs[5],
engine.X[0].limbs_storage.limbs[6],
engine.X[0].limbs_storage.limbs[7]
engine.X[1].limbs_storage.limbs[0],
engine.X[2].limbs_storage.limbs[0],
engine.X[3].limbs_storage.limbs[0],
engine.X[4].limbs_storage.limbs[0],
engine.X[5].limbs_storage.limbs[0],
engine.X[6].limbs_storage.limbs[0],
engine.X[7].limbs_storage.limbs[0]
);
// }
engine.loadBasicTwiddlesGeneric64(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv, false);
#pragma unroll 1
for (uint32_t phase = 0; phase < 2; phase++) {
engine.ntt8();

// if (threadIdx.x == 0) {
printf(
"T AFTER: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
threadIdx.x,
engine.X[0].limbs_storage.limbs[0],
engine.X[0].limbs_storage.limbs[1],
engine.X[0].limbs_storage.limbs[2],
engine.X[0].limbs_storage.limbs[3],
engine.X[0].limbs_storage.limbs[4],
engine.X[0].limbs_storage.limbs[5],
engine.X[0].limbs_storage.limbs[6],
engine.X[0].limbs_storage.limbs[7]
);
// printf(
// "T AFTER: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
// threadIdx.x,
// engine.X[0].limbs_storage.limbs[0],
// engine.X[1].limbs_storage.limbs[0],
// engine.X[2].limbs_storage.limbs[0],
// engine.X[3].limbs_storage.limbs[0],
// engine.X[4].limbs_storage.limbs[0],
// engine.X[5].limbs_storage.limbs[0],
// engine.X[6].limbs_storage.limbs[0],
// engine.X[7].limbs_storage.limbs[0]
// );
// }
if (phase == 0) {
engine.loadBasicTwiddlesGeneric64(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv, true);
engine.SharedData64Columns8(shmem, true, false, strided); // store
__syncthreads();
engine.SharedData64Rows8(shmem, false, false, strided); // load
printf(
"T AFTER: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
threadIdx.x,
engine.X[0].limbs_storage.limbs[0],
engine.X[1].limbs_storage.limbs[0],
engine.X[2].limbs_storage.limbs[0],
engine.X[3].limbs_storage.limbs[0],
engine.X[4].limbs_storage.limbs[0],
engine.X[5].limbs_storage.limbs[0],
engine.X[6].limbs_storage.limbs[0],
engine.X[7].limbs_storage.limbs[0]
);
}
}

Expand Down Expand Up @@ -700,7 +724,7 @@ namespace mxntt {

int stage = log_size - 1;
uint32_t stage_rev = 0;
S* stage_ptr = basic_twiddles;
S* stage_ptr = basic_twiddles + (stage * (1 << stage));
const int NOF_BLOCKS = (stage >= 8) ? (1 << (stage - 8)) : 1;
const int NOF_THREADS = (stage >= 8) ? 256 : (1 << stage);
// std::cout << "Stage: " << stage << "; nof_blocks: " << NOF_BLOCKS << "; nof_threads: " << NOF_THREADS << "; step:
Expand All @@ -709,7 +733,7 @@ namespace mxntt {
CHK_IF_RETURN(cudaPeekAtLastError());

for (--stage; stage >= 0; stage--) {
stage_ptr += 1 << (log_size - 1);
stage_ptr -= 1 << (log_size - 1);
stage_rev++;
// std::cout << "Stage: " << stage << "; nof_blocks: " << NOF_BLOCKS << "; nof_threads: " << NOF_THREADS << ";
// step: " << step << "; temp_root: " << temp_root <<"; stage_ptr: " << stage_ptr<< std::endl;
Expand Down
44 changes: 22 additions & 22 deletions icicle/src/ntt/thread_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ public:
uint32_t block_offset = s_meta.ntt_inp_id * 4;
for (int i = 0; i < 4; i++) {
if (phase) {
exp = phase_offset + s_meta.ntt_inp_id + (stage * 4 + i) * 8;
exp = phase_offset + stage_offset + block_offset + i;
} else {
exp = stage_offset + block_offset + i;
exp = s_meta.ntt_inp_id + (stage * 4 + i) * 8;
}

if (threadIdx.x == 0) {
// if (threadIdx.x == 0) {
printf(
"T: %d, I: %d, stage_offset: %d, block_offset: %d, exp: %d, tw: 0x%x\n",
threadIdx.x,
Expand All @@ -87,7 +87,7 @@ public:
exp,
basic_twiddles[exp].limbs_storage.limbs[0]
);
}
// }

WB[stage * 4 + i] = basic_twiddles[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
}
Expand Down Expand Up @@ -130,12 +130,12 @@ public:
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * s_meta.th_stride;
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
}

UNROLL
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[i * data_stride_u64];
X[i] = data[s_meta.th_stride * i * data_stride_u64];
}
}

Expand Down Expand Up @@ -404,16 +404,16 @@ public:
{
E T;

// Stage 2
X[1] = X[1] * WB[0];
X[3] = X[3] * WB[1];
X[5] = X[5] * WB[2];
// Stage 0
X[4] = X[4] * WB[0];
X[5] = X[5] * WB[1];
X[6] = X[6] * WB[2];
X[7] = X[7] * WB[3];

BF(T, X[0], X[1]);
BF(T, X[2], X[3]);
BF(T, X[4], X[5]);
BF(T, X[6], X[7]);
BF(T, X[0], X[4]);
BF(T, X[1], X[5]);
BF(T, X[2], X[6]);
BF(T, X[3], X[7]);

// Stage 1
X[2] = X[2] * WB[4];
Expand All @@ -426,16 +426,16 @@ public:
BF(T, X[4], X[6]);
BF(T, X[5], X[7]);

// Stage 0
X[4] = X[4] * WB[8];
X[5] = X[5] * WB[9];
X[6] = X[6] * WB[10];
// Stage 2
X[1] = X[1] * WB[8];
X[3] = X[3] * WB[9];
X[5] = X[5] * WB[10];
X[7] = X[7] * WB[11];

BF(T, X[0], X[4]);
BF(T, X[1], X[5]);
BF(T, X[2], X[6]);
BF(T, X[3], X[7]);
BF(T, X[0], X[1]);
BF(T, X[2], X[3]);
BF(T, X[4], X[5]);
BF(T, X[6], X[7]);
}

DEVICE_INLINE void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
Expand Down

0 comments on commit aa06501

Please sign in to comment.