Skip to content

Commit

Permalink
correctness for 6
Browse files Browse the repository at this point in the history
  • Loading branch information
ChickenLover committed Sep 26, 2024
1 parent 6e48a38 commit 2fc4614
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 20 deletions.
43 changes: 28 additions & 15 deletions icicle/src/ntt/kernel_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -415,24 +415,33 @@ namespace mxntt {
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;

engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (dit) {
engine.loadGlobalData64(in, data_stride, log_data_stride, strided, s_meta);
} else {
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
}

// printf(
// "T Before: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
// threadIdx.x,
// engine.X[0].limbs_storage.limbs[0],
// engine.X[1].limbs_storage.limbs[0],
// engine.X[2].limbs_storage.limbs[0],
// engine.X[3].limbs_storage.limbs[0],
// engine.X[4].limbs_storage.limbs[0],
// engine.X[5].limbs_storage.limbs[0],
// engine.X[6].limbs_storage.limbs[0],
// engine.X[7].limbs_storage.limbs[0]
// );
if (s_meta.ntt_block_id < 2)
printf(
"T Before: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n",
threadIdx.x,
engine.X[0].limbs_storage.limbs[0],
engine.X[1].limbs_storage.limbs[0],
engine.X[2].limbs_storage.limbs[0],
engine.X[3].limbs_storage.limbs[0],
engine.X[4].limbs_storage.limbs[0],
engine.X[5].limbs_storage.limbs[0],
engine.X[6].limbs_storage.limbs[0],
engine.X[7].limbs_storage.limbs[0]
);
#pragma unroll 1
for (uint32_t phase = 0; phase < 2; phase++) {
engine.loadBasicTwiddlesGeneric(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, twiddles_offset, 6, inv, dit, phase);
engine.ntt8();
if (inv) {
engine.intt8();
} else {
engine.ntt8();
}

if (phase == 0) {
engine.SharedData64Columns8(shmem, true, false, strided); // store
Expand Down Expand Up @@ -465,7 +474,11 @@ namespace mxntt {
// engine.X[7].limbs_storage.limbs[0]
// );

engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
if (dit) {
engine.storeGlobalDataDit(out, data_stride, log_data_stride, strided, s_meta);
} else {
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
}

template <typename E, typename S>
Expand Down
10 changes: 5 additions & 5 deletions icicle/src/ntt/ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,11 @@ namespace ntt {
CHK_IF_RETURN(mxntt::generate_twiddles_dcct(
primitive_root, domain.basic_twiddles, domain.max_log_size, ctx.stream));

S* tmp = static_cast<S*>(malloc(number_of_twiddles * sizeof(S)));
cudaMemcpy(tmp, domain.basic_twiddles, number_of_twiddles * sizeof(S), cudaMemcpyDeviceToHost);
for (size_t i = 0; i < number_of_twiddles; i++) {
std::cout << tmp[i] << std::endl;
}
// S* tmp = static_cast<S*>(malloc(number_of_twiddles * sizeof(S)));
// cudaMemcpy(tmp, domain.basic_twiddles, number_of_twiddles * sizeof(S), cudaMemcpyDeviceToHost);
// for (size_t i = 0; i < number_of_twiddles; i++) {
// std::cout << tmp[i] << std::endl;
// }
domain.coset_index[S::one()] = 0;
#else
// allocate and calculate twiddles on GPU
Expand Down
33 changes: 33 additions & 0 deletions icicle/src/ntt/thread_ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ public:
X[i] = data[s_meta.th_stride * i * data_stride_u64];
}
}
DEVICE_INLINE void
loadGlobalData64(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 8 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 8;
}

UNROLL
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[i * data_stride_u64];
}
}

DEVICE_INLINE void
storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
Expand All @@ -149,6 +165,23 @@ public:
}
}

DEVICE_INLINE void
storeGlobalDataDit(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
const uint64_t data_stride_u64 = data_stride;
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size;
} else {
data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
}

UNROLL
for (uint32_t i = 0; i < 8; i++) {
data[i * 8 * data_stride_u64] = X[i];
}
}

DEVICE_INLINE void
loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
Expand Down

0 comments on commit 2fc4614

Please sign in to comment.