diff --git a/icicle/src/ntt/kernel_ntt.cu b/icicle/src/ntt/kernel_ntt.cu index b5607de88e..007eb43990 100644 --- a/icicle/src/ntt/kernel_ntt.cu +++ b/icicle/src/ntt/kernel_ntt.cu @@ -415,24 +415,33 @@ namespace mxntt { if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size)) return; - engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta); + if (dit) { + engine.loadGlobalData64(in, data_stride, log_data_stride, strided, s_meta); + } else { + engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta); + } - // printf( - // "T Before: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n", - // threadIdx.x, - // engine.X[0].limbs_storage.limbs[0], - // engine.X[1].limbs_storage.limbs[0], - // engine.X[2].limbs_storage.limbs[0], - // engine.X[3].limbs_storage.limbs[0], - // engine.X[4].limbs_storage.limbs[0], - // engine.X[5].limbs_storage.limbs[0], - // engine.X[6].limbs_storage.limbs[0], - // engine.X[7].limbs_storage.limbs[0] - // ); + if (s_meta.ntt_block_id < 2) + printf( + "T Before: %d\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n0x%x\n", + threadIdx.x, + engine.X[0].limbs_storage.limbs[0], + engine.X[1].limbs_storage.limbs[0], + engine.X[2].limbs_storage.limbs[0], + engine.X[3].limbs_storage.limbs[0], + engine.X[4].limbs_storage.limbs[0], + engine.X[5].limbs_storage.limbs[0], + engine.X[6].limbs_storage.limbs[0], + engine.X[7].limbs_storage.limbs[0] + ); #pragma unroll 1 for (uint32_t phase = 0; phase < 2; phase++) { engine.loadBasicTwiddlesGeneric(basic_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, twiddles_offset, 6, inv, dit, phase); - engine.ntt8(); + if (inv) { + engine.intt8(); + } else { + engine.ntt8(); + } if (phase == 0) { engine.SharedData64Columns8(shmem, true, false, strided); // store @@ -465,7 +474,11 @@ namespace mxntt { // engine.X[7].limbs_storage.limbs[0] // ); - engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta); + if (dit) { + engine.storeGlobalDataDit(out, data_stride, log_data_stride, strided, s_meta); + } else { + engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta); + } } template diff --git a/icicle/src/ntt/ntt.cu b/icicle/src/ntt/ntt.cu index dc69b6082b..341c087fb3 100644 --- a/icicle/src/ntt/ntt.cu +++ b/icicle/src/ntt/ntt.cu @@ -478,11 +478,11 @@ namespace ntt { CHK_IF_RETURN(mxntt::generate_twiddles_dcct( primitive_root, domain.basic_twiddles, domain.max_log_size, ctx.stream)); - S* tmp = static_cast(malloc(number_of_twiddles * sizeof(S))); - cudaMemcpy(tmp, domain.basic_twiddles, number_of_twiddles * sizeof(S), cudaMemcpyDeviceToHost); - for (size_t i = 0; i < number_of_twiddles; i++) { - std::cout << tmp[i] << std::endl; - } + // S* tmp = static_cast(malloc(number_of_twiddles * sizeof(S))); + // cudaMemcpy(tmp, domain.basic_twiddles, number_of_twiddles * sizeof(S), cudaMemcpyDeviceToHost); + // for (size_t i = 0; i < number_of_twiddles; i++) { + // std::cout << tmp[i] << std::endl; + // } domain.coset_index[S::one()] = 0; #else // allocate and calculate twiddles on GPU diff --git a/icicle/src/ntt/thread_ntt.cu b/icicle/src/ntt/thread_ntt.cu index e60c81ea5a..5a6639f963 100644 --- a/icicle/src/ntt/thread_ntt.cu +++ b/icicle/src/ntt/thread_ntt.cu @@ -131,6 +131,22 @@ public: X[i] = data[s_meta.th_stride * i * data_stride_u64]; } } + DEVICE_INLINE void + loadGlobalData64(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + { + const uint64_t data_stride_u64 = data_stride; + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id * 8 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size; + } else { + data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 8; + } + + UNROLL + for (uint32_t i = 0; i < 8; i++) { + X[i] = data[i * data_stride_u64]; + } + } DEVICE_INLINE void storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) @@ -149,6 +165,23 @@ public: } } + DEVICE_INLINE void + storeGlobalDataDit(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) + { + const uint64_t data_stride_u64 = data_stride; + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride_u64 * s_meta.ntt_inp_id + + (s_meta.ntt_block_id >> log_data_stride) * data_stride_u64 * s_meta.ntt_block_size; + } else { + data += (uint64_t)s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id; + } + + UNROLL + for (uint32_t i = 0; i < 8; i++) { + data[i * 8 * data_stride_u64] = X[i]; + } + } + DEVICE_INLINE void loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta) {