From 3fd217e4fa569d49fdb3d86a7e1964e810d07deb Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Wed, 29 Nov 2023 15:50:35 -0800 Subject: [PATCH] -- wip -- --- Cargo.toml | 9 ++-- benches/main.rs | 11 +++++ cuda/pallas.cu | 34 +++++++++++++-- cuda/vesta.cu | 34 +++++++++++++-- src/lib.rs | 112 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 188 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9d541f0..377a56a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,16 +21,17 @@ include = [ default = [] # Compile in portable mode, without ISA extensions. # Binary can be executed on all systems. -portable = [ "semolina/portable" ] +portable = ["semolina/portable"] # Enable ADX even if the host CPU doesn't support it. # Binary can be executed on Broadwell+ and Ryzen+ systems. -force-adx = [ "semolina/force-adx" ] +force-adx = ["semolina/force-adx"] cuda-mobile = [] [dependencies] semolina = "~0.1.3" -sppark = "~0.1.2" -pasta_curves = { git="https://github.com/lurk-lab/pasta_curves", branch="dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] } +sppark = { git = "https://github.com/lurk-lab/sppark.git", branch = "preallocated-msm" } +pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] } +paste = "1.0.14" [build-dependencies] cc = "^1.0.70" diff --git a/benches/main.rs b/benches/main.rs index 5bf3f56..5866e32 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -66,6 +66,17 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let context = pasta_msm::pallas_init(&points, npoints); + + group.bench_function( + format!("preallocated 2**{} points", bench_npow), + |b| { + b.iter(|| { + let _ = pasta_msm::pallas_with(&context, npoints, &scalars); + }) + }, + ); + group.finish(); } } diff --git a/cuda/pallas.cu b/cuda/pallas.cu index f3897bb..2a02eae 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -17,8 +17,34 @@ typedef vesta_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +struct msm_context_pallas_t +{ + affine_t::mem_t *d_points; +}; + +extern "C" void drop_msm_context_t(msm_context_t &ref) +{ + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_pippenger_pallas_init(const affine_t points[], size_t npoints, msm_context_pallas_t *msm_context) +{ + msm_t msm{points, npoints, false}; + msm_context->d_points = msm.get_d_points(); + return RustError{cudaSuccess}; +} + +extern "C" RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} + +extern "C" RustError cuda_pippenger_pallas_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context->d_points, npoints, scalars); +} #endif diff --git a/cuda/vesta.cu b/cuda/vesta.cu index a926c6d..f1ac29a 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -17,8 +17,34 @@ typedef pallas_t scalar_t; #include #ifndef __CUDA_ARCH__ -extern "C" -RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, - const scalar_t scalars[]) -{ return mult_pippenger(out, points, npoints, scalars); } + +struct msm_context_vesta_t +{ + affine_t::mem_t *d_points; +}; + +extern "C" void drop_msm_context_t(msm_context_t &ref) +{ + CUDA_OK(cudaFree(ref.d_points)); +} + +extern "C" RustError +cuda_pippenger_vesta_init(const affine_t points[], size_t npoints, msm_context_vesta_t *msm_context) +{ + msm_t msm{points, npoints, false}; + msm_context->d_points = msm.get_d_points(); + return RustError{cudaSuccess}; +} + +extern "C" RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger(out, points, npoints, scalars); +} + +extern "C" RustError cuda_pippenger_vesta_with(point_t *out, msm_context_t *msm_context, size_t npoints, + const scalar_t scalars[]) +{ + return mult_pippenger_with(out, msm_context->d_points, npoints, scalars); +} #endif diff --git a/src/lib.rs b/src/lib.rs index 00e6878..89a1052 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,45 @@ extern crate semolina; +#[cfg(feature = "cuda")] +#[repr(C)] +#[derive(Debug)] +pub struct MSMContext { + context: *const std::ffi::c_void, +} + +#[cfg(feature = "cuda")] +unsafe impl Send for MSMContext {} + +#[cfg(feature = "cuda")] +unsafe impl Sync for MSMContext {} + +#[cfg(feature = "cuda")] +impl Default for MSMContext { + fn default() -> Self { + Self { context: std::ptr::null() } + } +} + +#[cfg(feature = "cuda")] +impl Clone for MSMContext { + fn clone(&self) -> Self { + Self { context: self.context.clone() } + } +} + +#[cfg(feature = "cuda")] +// TODO: check for device-side memory leaks +impl Drop for MSMContext { + fn drop(&mut self) { + extern "C" { + fn drop_msm_context_t(by_ref: &MSMContext); + } + unsafe { drop_msm_context_t(std::mem::transmute::<&_, &_>(self)) }; + self.context = core::ptr::null(); + } +} + #[cfg(feature = "cuda")] sppark::cuda_error!(); #[cfg(feature = "cuda")] @@ -31,6 +70,79 @@ macro_rules! multi_scalar_mult { ); } + paste::paste! { + #[cfg(feature = "cuda")] + pub fn [<$pasta _init>]( + points: &[$pasta::Affine], + npoints: usize, + ) -> MSMContext { + unsafe { assert!(!CUDA_OFF && cuda_available(), "feature = \"cuda\" must be enabled") }; + if npoints != points.len() && npoints < 1 << 16 { + panic!("length mismatch or less than 10**16") + } + extern "C" { + fn []( + points: *const pallas::Affine, + npoints: usize, + d_points: &mut MSMContext, + is_mont: bool, + ) -> cuda::Error; + + } + + let mut ret = MSMContext::default(); + let err = unsafe { + []( + points.as_ptr() as *const _, + npoints, + &mut ret, + true, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } + ret + } + + #[cfg(feature = "cuda")] + pub fn [<$pasta _with>]( + context: &MSMContext, + npoints: usize, + scalars: &[$pasta::Scalar], + ) -> $pasta::Point { + unsafe { assert!(!CUDA_OFF && cuda_available(), "feature = \"cuda\" must be enabled") }; + if npoints != scalars.len() && npoints < 1 << 16 { + panic!("length mismatch or less than 10**16") + } + extern "C" { + fn []( + out: *mut $pasta::Point, + context: &MSMContext, + npoints: usize, + scalars: *const $pasta::Scalar, + is_mont: bool, + ) -> cuda::Error; + + } + + let mut ret = $pasta::Point::default(); + let err = unsafe { + []( + &mut ret, + context, + npoints, + &scalars[0], + true, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } + ret + } + } + pub fn $pasta( points: &[$pasta::Affine], scalars: &[$pasta::Scalar],