From 436f40178b07c536f9f74c47b1198206923510fc Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Sun, 29 Sep 2024 13:48:51 +0300 Subject: [PATCH] Add FRI folding for M31 (#618) This PR adds FRI folding for M31 field. --- icicle/include/fri/fri.cuh | 64 ++++ icicle/src/fields/CMakeLists.txt | 6 + icicle/src/fri/extern.cu | 55 +++ icicle/src/fri/fri.cu | 154 +++++++++ .../icicle-fields/icicle-m31/src/fri/mod.rs | 313 ++++++++++++++++++ .../rust/icicle-fields/icicle-m31/src/lib.rs | 1 + 6 files changed, 593 insertions(+) create mode 100644 icicle/include/fri/fri.cuh create mode 100644 icicle/src/fri/extern.cu create mode 100644 icicle/src/fri/fri.cu create mode 100644 wrappers/rust/icicle-fields/icicle-m31/src/fri/mod.rs diff --git a/icicle/include/fri/fri.cuh b/icicle/include/fri/fri.cuh new file mode 100644 index 000000000..7d20e6892 --- /dev/null +++ b/icicle/include/fri/fri.cuh @@ -0,0 +1,64 @@ +#pragma once +#ifndef FRI_H +#define FRI_H + +#include + +#include "gpu-utils/device_context.cuh" + +namespace fri { + + struct FriConfig { + device_context::DeviceContext ctx; + bool are_evals_on_device; + bool are_domain_elements_on_device; + bool are_results_on_device; + bool is_async; + }; + + /** + * @brief Folds a layer's evaluation into a degree d/2 evaluation using the provided folding factor alpha. + * + * @param evals Pointer to the array of evaluation in the current FRI layer. + * @param domain_xs Pointer to a subset of line domain values. + * @param alpha The folding factor used in the FRI protocol. + * @param folded_evals Pointer to the array where the folded evaluations will be stored. + * @param n The number of evaluations in the original layer (before folding). + * + * @tparam S The scalar field type used for domain_xs. + * @tparam E The evaluation type, typically the same as the field element type. + * + * @note The size of the output array 'folded_evals' should be half of 'n', as folding reduces the number of + * evaluations by half. + */ + template + cudaError_t fold_line(E* eval, S* domain_xs, E alpha, E* folded_eval, uint64_t n, FriConfig& cfg); + + /** + * @brief Folds a layer of FRI evaluations from a circle into a line. + * + * This function performs the folding operation in the FRI (Fast Reed-Solomon IOP of Proximity) protocol, + * specifically for evaluations on a circle domain. It takes a layer of evaluations on a circle and folds + * them into a line using the provided folding factor alpha. + * + * @param evals Pointer to the array of evaluations in the current FRI layer, representing points on a circle. + * @param domain_ys Pointer to the array of y-coordinates of the circle points in the domain of the circle that evals + * represents. + * @param alpha The folding factor used in the FRI protocol. + * @param folded_evals Pointer to the array where the folded evaluations (now on a line) will be stored. + * @param n The number of evaluations in the original layer (before folding). + * + * @tparam S The scalar field type used for alpha and domain_ys. + * @tparam E The evaluation type, typically the same as the field element type. + * + * @note The size of the output array 'folded_evals' should be half of 'n', as folding reduces the number of + * evaluations by half. + * @note This function is specifically designed for folding evaluations from a circular domain to a linear domain. + */ + + template + cudaError_t fold_circle_into_line(E* eval, S* domain_ys, E alpha, E* folded_eval, uint64_t n, FriConfig& cfg); + +} // namespace fri + +#endif \ No newline at end of file diff --git a/icicle/src/fields/CMakeLists.txt b/icicle/src/fields/CMakeLists.txt index 1853e319f..0a0138894 100644 --- a/icicle/src/fields/CMakeLists.txt +++ b/icicle/src/fields/CMakeLists.txt @@ -4,6 +4,7 @@ endif () SET(SUPPORTED_FIELDS_WITHOUT_NTT grumpkin;m31) SET(SUPPORTED_FIELDS_WITHOUT_POSEIDON2 bls12_381;bls12_377;grumpkin;bw6_761;stark252;m31) +SET(SUPPORTED_FIELDS_WITH_FRI m31) set(TARGET icicle_field) @@ -42,6 +43,11 @@ if (NOT FIELD IN_LIST SUPPORTED_FIELDS_WITHOUT_NTT) list(APPEND FIELD_SOURCE ${POLYNOMIAL_SOURCE_FILES}) # requires NTT endif() +if (FIELD IN_LIST SUPPORTED_FIELDS_WITH_FRI) + list(APPEND FIELD_SOURCE ${SRC}/fri/extern.cu) + list(APPEND FIELD_SOURCE ${SRC}/fri/fri.cu) +endif() + add_library(${TARGET} STATIC ${FIELD_SOURCE}) target_include_directories(${TARGET} PUBLIC ${CMAKE_SOURCE_DIR}/include/) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "ingo_field_${FIELD}") diff --git a/icicle/src/fri/extern.cu b/icicle/src/fri/extern.cu new file mode 100644 index 000000000..e2fdaa5e6 --- /dev/null +++ b/icicle/src/fri/extern.cu @@ -0,0 +1,55 @@ +#include "fields/field_config.cuh" +using namespace field_config; + +#include "fri.cu" +#include "utils/utils.h" + +namespace fri { + /** + * Extern "C" version of [fold_line](@ref fold_line) function with the following values of + * template parameters (where the field is given by `-DFIELD` env variable during build): + * - `E` is the extension field type used for evaluations and alpha + * - `S` is the scalar field type used for domain elements + * @param line_eval Pointer to the array of evaluations on the line + * @param domain_elements Pointer to the array of domain elements + * @param alpha The folding factor + * @param folded_evals Pointer to the array where folded evaluations will be stored + * @param n The number of evaluations + * @param ctx The device context; if the stream is not 0, then everything is run async + * @return `cudaSuccess` if the execution was successful and an error code otherwise. + */ + extern "C" cudaError_t CONCAT_EXPAND(FIELD, fold_line)( + extension_t* line_eval, + scalar_t* domain_elements, + extension_t alpha, + extension_t* folded_evals, + uint64_t n, + FriConfig& cfg) + { + return fri::fold_line(line_eval, domain_elements, alpha, folded_evals, n, cfg); + }; + + /** + * Extern "C" version of [fold_circle_into_line](@ref fold_circle_into_line) function with the following values of + * template parameters (where the field is given by `-DFIELD` env variable during build): + * - `E` is the extension field type used for evaluations and alpha + * - `S` is the scalar field type used for domain elements + * @param circle_evals Pointer to the array of evaluations on the circle + * @param domain_elements Pointer to the array of domain elements + * @param alpha The folding factor + * @param folded_line_evals Pointer to the array where folded evaluations will be stored + * @param n The number of evaluations + * @param ctx The device context; if the stream is not 0, then everything is run async + * @return `cudaSuccess` if the execution was successful and an error code otherwise. + */ + extern "C" cudaError_t CONCAT_EXPAND(FIELD, fold_circle_into_line)( + extension_t* circle_evals, + scalar_t* domain_elements, + extension_t alpha, + extension_t* folded_line_evals, + uint64_t n, + FriConfig& cfg) + { + return fri::fold_circle_into_line(circle_evals, domain_elements, alpha, folded_line_evals, n, cfg); + }; +} // namespace fri diff --git a/icicle/src/fri/fri.cu b/icicle/src/fri/fri.cu new file mode 100644 index 000000000..7d66cb3e6 --- /dev/null +++ b/icicle/src/fri/fri.cu @@ -0,0 +1,154 @@ +#include + +#include "fri/fri.cuh" + +#include "fields/field.cuh" +#include "gpu-utils/error_handler.cuh" +#include "gpu-utils/device_context.cuh" + +namespace fri { + + namespace { + template + __device__ void ibutterfly(E& v0, E& v1, const S& itwid) + { + E tmp = v0; + v0 = tmp + v1; + v1 = (tmp - v1) * itwid; + } + + template + __global__ void fold_line_kernel(E* eval, S* domain_xs, E alpha, E* folded_eval, uint64_t n) + { + unsigned idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx % 2 == 0 && idx < n) { + E f_x = eval[idx]; // even + E f_x_neg = eval[idx + 1]; // odd + S x_domain = domain_xs[idx / 2]; + ibutterfly(f_x, f_x_neg, S::inverse(x_domain)); + auto folded_eval_idx = idx / 2; + folded_eval[folded_eval_idx] = f_x + alpha * f_x_neg; + } + } + + template + __global__ void fold_circle_into_line_kernel(E* eval, S* domain_ys, E alpha, E alpha_sq, E* folded_eval, uint64_t n) + { + unsigned idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx % 2 == 0 && idx < n) { + E f0_px = eval[idx]; + E f1_px = eval[idx + 1]; + ibutterfly(f0_px, f1_px, S::inverse(domain_ys[idx / 2])); + E f_prime = f0_px + alpha * f1_px; + auto folded_eval_idx = idx / 2; + folded_eval[folded_eval_idx] = folded_eval[folded_eval_idx] * alpha_sq + f_prime; + } + } + } // namespace + + template + cudaError_t fold_line(E* eval, S* domain_xs, E alpha, E* folded_eval, uint64_t n, FriConfig& cfg) + { + CHK_INIT_IF_RETURN(); + + cudaStream_t stream = cfg.ctx.stream; + // Allocate and move line domain evals to device if necessary + E* d_eval; + if (!cfg.are_evals_on_device) { + auto data_size = sizeof(E) * n; + CHK_IF_RETURN(cudaMallocAsync(&d_eval, data_size, stream)); + CHK_IF_RETURN(cudaMemcpyAsync(d_eval, eval, data_size, cudaMemcpyHostToDevice, stream)); + } else { + d_eval = eval; + } + + // Allocate and move domain's elements to device if necessary + S* d_domain_xs; + if (!cfg.are_domain_elements_on_device) { + auto data_size = sizeof(S) * n / 2; + CHK_IF_RETURN(cudaMallocAsync(&d_domain_xs, data_size, stream)); + CHK_IF_RETURN(cudaMemcpyAsync(d_domain_xs, domain_xs, data_size, cudaMemcpyHostToDevice, stream)); + } else { + d_domain_xs = domain_xs; + } + + // Allocate folded_eval if pointer is not a device pointer + E* d_folded_eval; + if (!cfg.are_results_on_device) { + CHK_IF_RETURN(cudaMallocAsync(&d_folded_eval, sizeof(E) * n / 2, stream)); + } else { + d_folded_eval = folded_eval; + } + + uint64_t num_threads = 256; + uint64_t num_blocks = (n / 2 + num_threads - 1) / num_threads; + fold_line_kernel<<>>(d_eval, d_domain_xs, alpha, d_folded_eval, n); + + // Move folded_eval back to host if requested + if (!cfg.are_results_on_device) { + CHK_IF_RETURN(cudaMemcpyAsync(folded_eval, d_folded_eval, sizeof(E) * n / 2, cudaMemcpyDeviceToHost, stream)); + CHK_IF_RETURN(cudaFreeAsync(d_folded_eval, stream)); + } + if (!cfg.are_domain_elements_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_domain_xs, stream)); } + if (!cfg.are_evals_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_eval, stream)); } + + // Sync if stream is default stream + if (stream == 0) CHK_IF_RETURN(cudaStreamSynchronize(stream)); + + return CHK_LAST(); + } + + template + cudaError_t fold_circle_into_line(E* eval, S* domain_ys, E alpha, E* folded_eval, uint64_t n, FriConfig& cfg) + { + CHK_INIT_IF_RETURN(); + + cudaStream_t stream = cfg.ctx.stream; + // Allocate and move circle domain evals to device if necessary + E* d_eval; + if (!cfg.are_evals_on_device) { + auto data_size = sizeof(E) * n; + CHK_IF_RETURN(cudaMallocAsync(&d_eval, data_size, stream)); + CHK_IF_RETURN(cudaMemcpyAsync(d_eval, eval, data_size, cudaMemcpyHostToDevice, stream)); + } else { + d_eval = eval; + } + + // Allocate and move domain's elements to device if necessary + S* d_domain_ys; + if (!cfg.are_domain_elements_on_device) { + auto data_size = sizeof(S) * n / 2; + CHK_IF_RETURN(cudaMallocAsync(&d_domain_ys, data_size, stream)); + CHK_IF_RETURN(cudaMemcpyAsync(d_domain_ys, domain_ys, data_size, cudaMemcpyHostToDevice, stream)); + } else { + d_domain_ys = domain_ys; + } + + // Allocate folded_evals if pointer is not a device pointer + E* d_folded_eval; + if (!cfg.are_results_on_device) { + CHK_IF_RETURN(cudaMallocAsync(&d_folded_eval, sizeof(E) * n / 2, stream)); + } else { + d_folded_eval = folded_eval; + } + + E alpha_sq = alpha * alpha; + uint64_t num_threads = 256; + uint64_t num_blocks = (n / 2 + num_threads - 1) / num_threads; + fold_circle_into_line_kernel<<>>( + d_eval, d_domain_ys, alpha, alpha_sq, d_folded_eval, n); + + // Move folded_evals back to host if requested + if (!cfg.are_results_on_device) { + CHK_IF_RETURN(cudaMemcpyAsync(folded_eval, d_folded_eval, sizeof(E) * n / 2, cudaMemcpyDeviceToHost, stream)); + CHK_IF_RETURN(cudaFreeAsync(d_folded_eval, stream)); + } + if (!cfg.are_domain_elements_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_domain_ys, stream)); } + if (!cfg.are_evals_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_eval, stream)); } + + // Sync if stream is default stream + if (stream == 0) CHK_IF_RETURN(cudaStreamSynchronize(stream)); + + return CHK_LAST(); + } +} // namespace fri diff --git a/wrappers/rust/icicle-fields/icicle-m31/src/fri/mod.rs b/wrappers/rust/icicle-fields/icicle-m31/src/fri/mod.rs new file mode 100644 index 000000000..f75aea68b --- /dev/null +++ b/wrappers/rust/icicle-fields/icicle-m31/src/fri/mod.rs @@ -0,0 +1,313 @@ +use crate::field::{ExtensionField, ScalarField}; +use icicle_core::error::IcicleResult; +use icicle_core::traits::IcicleResultWrap; +use icicle_cuda_runtime::device::check_device; +use icicle_cuda_runtime::device_context::{DeviceContext, DEFAULT_DEVICE_ID}; +use icicle_cuda_runtime::error::CudaError; +use icicle_cuda_runtime::memory::HostOrDeviceSlice; + +/// Struct that encodes FRI parameters. +#[repr(C)] +#[derive(Debug, Clone)] +pub struct FriConfig<'a> { + /// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext). + pub ctx: DeviceContext<'a>, + are_evals_on_device: bool, + are_domain_elements_on_device: bool, + are_results_on_device: bool, + /// Whether to run the vector operations asynchronously. If set to `true`, the functions will be non-blocking and you'd need to synchronize + /// it explicitly by running `stream.synchronize()`. If set to false, the functions will block the current CPU thread. + pub is_async: bool, +} + +impl<'a> Default for FriConfig<'a> { + fn default() -> Self { + Self::default_for_device(DEFAULT_DEVICE_ID) + } +} + +impl<'a> FriConfig<'a> { + pub fn default_for_device(device_id: usize) -> Self { + FriConfig { + ctx: DeviceContext::default_for_device(device_id), + are_evals_on_device: false, + are_domain_elements_on_device: false, + are_results_on_device: false, + is_async: false, + } + } +} + +fn check_fri_args<'a, F, S>( + eval: &(impl HostOrDeviceSlice + ?Sized), + domain_elements: &(impl HostOrDeviceSlice + ?Sized), + folded_eval: &(impl HostOrDeviceSlice + ?Sized), + cfg: &FriConfig<'a>, +) -> FriConfig<'a> { + if eval.len() / 2 != domain_elements.len() { + panic!( + "Number of domain elements is not half of the evaluation's domain size; {} != {} / 2", + eval.len(), + domain_elements.len() + ); + } + + if eval.len() / 2 != folded_eval.len() { + panic!( + "Folded poly degree is not half of the evaluation poly's degree; {} != {} / 2", + eval.len(), + folded_eval.len() + ); + } + + let ctx_device_id = cfg + .ctx + .device_id; + + if let Some(device_id) = eval.device_id() { + assert_eq!(device_id, ctx_device_id, "Device ids in eval and context are different"); + } + if let Some(device_id) = domain_elements.device_id() { + assert_eq!( + device_id, ctx_device_id, + "Device ids in domain_elements and context are different" + ); + } + if let Some(device_id) = folded_eval.device_id() { + assert_eq!( + device_id, ctx_device_id, + "Device ids in folded_eval and context are different" + ); + } + check_device(ctx_device_id); + + let mut res_cfg = cfg.clone(); + res_cfg.are_evals_on_device = eval.is_on_device(); + res_cfg.are_domain_elements_on_device = domain_elements.is_on_device(); + res_cfg.are_results_on_device = folded_eval.is_on_device(); + res_cfg +} + +pub fn fold_line( + eval: &(impl HostOrDeviceSlice + ?Sized), + domain_elements: &(impl HostOrDeviceSlice + ?Sized), + folded_eval: &mut (impl HostOrDeviceSlice + ?Sized), + alpha: ExtensionField, + cfg: &FriConfig, +) -> IcicleResult<()> { + let cfg = check_fri_args(eval, domain_elements, folded_eval, cfg); + unsafe { + _fri::fold_line( + eval.as_ptr(), + domain_elements.as_ptr(), + &alpha, + folded_eval.as_mut_ptr(), + eval.len() as u64, + &cfg as *const FriConfig, + ) + .wrap() + } +} + +pub fn fold_circle_into_line( + eval: &(impl HostOrDeviceSlice + ?Sized), + domain_elements: &(impl HostOrDeviceSlice + ?Sized), + folded_eval: &mut (impl HostOrDeviceSlice + ?Sized), + alpha: ExtensionField, + cfg: &FriConfig, +) -> IcicleResult<()> { + let cfg = check_fri_args(eval, domain_elements, folded_eval, cfg); + unsafe { + _fri::fold_circle_into_line( + eval.as_ptr(), + domain_elements.as_ptr(), + &alpha, + folded_eval.as_mut_ptr(), + eval.len() as u64, + &cfg as *const FriConfig, + ) + .wrap() + } +} + +mod _fri { + use super::{CudaError, ExtensionField, FriConfig, ScalarField}; + + extern "C" { + #[link_name = "m31_fold_line"] + pub(crate) fn fold_line( + line_eval: *const ExtensionField, + domain_elements: *const ScalarField, + alpha: &ExtensionField, + folded_eval: *mut ExtensionField, + n: u64, + cfg: *const FriConfig, + ) -> CudaError; + + #[link_name = "m31_fold_circle_into_line"] + pub(crate) fn fold_circle_into_line( + circle_eval: *const ExtensionField, + domain_elements: *const ScalarField, + alpha: &ExtensionField, + folded_line_eval: *mut ExtensionField, + n: u64, + cfg: *const FriConfig, + ) -> CudaError; + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::field::{ExtensionField, ScalarField}; + use icicle_core::traits::FieldImpl; + use icicle_cuda_runtime::memory::{DeviceVec, HostSlice}; + use std::iter::zip; + + #[test] + fn test_fold_line() { + // All hardcoded values were generated with https://github.com/starkware-libs/stwo/blob/f976890/crates/prover/src/core/fri.rs#L1005-L1037 + const DEGREE: usize = 8; + + // Set evals + let evals_raw: [u32; DEGREE] = [ + 1358331652, 807347720, 543926930, 1585623140, 1753377641, 616790922, 630401694, 1294134897, + ]; + let evals_as_extension = evals_raw + .into_iter() + .map(|val: u32| ExtensionField::from_u32(val)) + .collect::>(); + let eval = HostSlice::from_slice(evals_as_extension.as_slice()); + let mut d_eval = DeviceVec::::cuda_malloc(DEGREE).unwrap(); + d_eval + .copy_from_host(eval) + .unwrap(); + + // Set domain + let domain_raw: [u32; DEGREE / 2] = [1179735656, 1241207368, 1415090252, 2112881577]; + let domain_as_scalar = domain_raw + .into_iter() + .map(|val: u32| ScalarField::from_u32(val)) + .collect::>(); + let domain_elements = HostSlice::from_slice(domain_as_scalar.as_slice()); + let mut d_domain_elements = DeviceVec::::cuda_malloc(DEGREE / 2).unwrap(); + d_domain_elements + .copy_from_host(domain_elements) + .unwrap(); + + // Alloc folded_evals + let mut folded_eval_raw = vec![ExtensionField::zero(); DEGREE / 2]; + let folded_eval = HostSlice::from_mut_slice(folded_eval_raw.as_mut_slice()); + let mut d_folded_eval = DeviceVec::::cuda_malloc(DEGREE / 2).unwrap(); + + let alpha = ExtensionField::from_u32(19283); + let cfg = FriConfig::default(); + + let res = fold_line(&d_eval[..], &d_domain_elements[..], &mut d_folded_eval[..], alpha, &cfg); + + assert!(res.is_ok()); + + let expected_folded_evals_raw: [u32; DEGREE / 2] = [547848116, 1352534073, 2053322292, 341725613]; + let expected_folded_evals_extension = expected_folded_evals_raw + .into_iter() + .map(|val: u32| ExtensionField::from_u32(val)) + .collect::>(); + let expected_folded_evals = expected_folded_evals_extension.as_slice(); + + d_folded_eval + .copy_to_host(folded_eval) + .unwrap(); + + for (i, (folded_eval_val, expected_folded_eval_val)) in + zip(folded_eval.as_slice(), expected_folded_evals).enumerate() + { + assert_eq!( + folded_eval_val, expected_folded_eval_val, + "Mismatch of folded eval at {i}" + ); + } + } + + #[test] + fn test_fold_circle_to_line() { + // All hardcoded values were generated with https://github.com/starkware-libs/stwo/blob/f976890/crates/prover/src/core/fri.rs#L1040-L1053 + const DEGREE: usize = 64; + let circle_eval_raw: [u32; DEGREE] = [ + 466407290, 127986842, 1870304883, 875137047, 1381744584, 1242514872, 1657247602, 1816542136, 18610701, + 183082621, 1291388290, 1665658712, 1768829380, 872721779, 1113994239, 827698214, 57598558, 1809783851, + 1582268514, 1018797774, 1927599636, 619773471, 802072749, 2111764399, 714973298, 532899888, 671071637, + 536208302, 1268828963, 255940280, 586928868, 535875357, 1650651309, 1473550629, 1387441966, 893930940, + 126593346, 1263510627, 18204497, 211871416, 604224095, 465540164, 1007455733, 755529771, 2130798047, + 871433949, 1073797249, 1097851807, 369407795, 302384846, 1904956607, 1168797665, 352925744, 10934213, + 409562797, 1646664722, 676414749, 35135895, 2606032, 2121020146, 1205801045, 1079025338, 2111544534, + 1635203417, + ]; + let circle_eval_as_extension = circle_eval_raw + .into_iter() + .map(|val: u32| ExtensionField::from_u32(val)) + .collect::>(); + let circle_eval = HostSlice::from_slice(circle_eval_as_extension.as_slice()); + let mut d_circle_eval = DeviceVec::::cuda_malloc(DEGREE).unwrap(); + d_circle_eval + .copy_from_host(circle_eval) + .unwrap(); + + let domain_raw: [u32; DEGREE / 2] = [ + 1774253895, 373229752, 1309288441, 838195206, 262191051, 1885292596, 408478793, 1739004854, 212443077, + 1935040570, 1941424532, 206059115, 883753057, 1263730590, 350742286, 1796741361, 404685994, 1742797653, + 7144319, 2140339328, 68458636, 2079025011, 2137679949, 9803698, 228509164, 1918974483, 2132953617, + 14530030, 134155457, 2013328190, 1108537731, 1038945916, + ]; + let domain_as_scalar = domain_raw + .into_iter() + .map(|val: u32| ScalarField::from_u32(val)) + .collect::>(); + let domain_elements = HostSlice::from_slice(domain_as_scalar.as_slice()); + let mut d_domain_elements = DeviceVec::::cuda_malloc(DEGREE / 2).unwrap(); + d_domain_elements + .copy_from_host(domain_elements) + .unwrap(); + + let mut folded_eval_raw = vec![ExtensionField::zero(); DEGREE / 2]; + let folded_eval = HostSlice::from_mut_slice(folded_eval_raw.as_mut_slice()); + let mut d_folded_eval = DeviceVec::::cuda_malloc(DEGREE / 2).unwrap(); + + let alpha = ExtensionField::one(); + let cfg = FriConfig::default(); + + let res = fold_circle_into_line( + &d_circle_eval[..], + &d_domain_elements[..], + &mut d_folded_eval[..], + alpha, + &cfg, + ); + + assert!(res.is_ok()); + + let expected_folded_evals_raw: [u32; DEGREE / 2] = [ + 1188788264, 1195916566, 953551618, 505128535, 403386644, 1619126710, 988135024, 1735901259, 1587281171, + 907165282, 799778920, 1532707002, 348262725, 267076231, 902054839, 98124803, 1953436582, 267778518, + 632724299, 460151826, 2139528518, 1378487361, 1709496698, 48330818, 1343585282, 1852541250, 727719914, + 1964971391, 1423101288, 2099768709, 274685472, 1051044961, + ]; + let expected_folded_evals_extension = expected_folded_evals_raw + .into_iter() + .map(|val: u32| ExtensionField::from_u32(val)) + .collect::>(); + let expected_folded_evals = expected_folded_evals_extension.as_slice(); + + d_folded_eval + .copy_to_host(folded_eval) + .unwrap(); + + for (i, (folded_eval_val, expected_folded_eval_val)) in + zip(folded_eval.as_slice(), expected_folded_evals).enumerate() + { + assert_eq!( + folded_eval_val, expected_folded_eval_val, + "Mismatch of folded eval at {i}" + ); + } + } +} diff --git a/wrappers/rust/icicle-fields/icicle-m31/src/lib.rs b/wrappers/rust/icicle-fields/icicle-m31/src/lib.rs index 001f51ba9..4f3bd8f62 100644 --- a/wrappers/rust/icicle-fields/icicle-m31/src/lib.rs +++ b/wrappers/rust/icicle-fields/icicle-m31/src/lib.rs @@ -1,2 +1,3 @@ pub mod field; +pub mod fri; pub mod vec_ops;