This repository has been archived by the owner on Feb 19, 2024. It is now read-only.

-- wip --
winston-h-zhang authored and Hanting Zhang committed Nov 30, 2023
1 parent 182b971 commit 3fd217e
Showing 5 changed files with 188 additions and 12 deletions.
9 changes: 5 additions & 4 deletions Cargo.toml
@@ -21,16 +21,17 @@ include = [
 default = []
 # Compile in portable mode, without ISA extensions.
 # Binary can be executed on all systems.
-portable = [ "semolina/portable" ]
+portable = ["semolina/portable"]
 # Enable ADX even if the host CPU doesn't support it.
 # Binary can be executed on Broadwell+ and Ryzen+ systems.
-force-adx = [ "semolina/force-adx" ]
+force-adx = ["semolina/force-adx"]
 cuda-mobile = []
 
 [dependencies]
 semolina = "~0.1.3"
-sppark = "~0.1.2"
-pasta_curves = { git="https://github.com/lurk-lab/pasta_curves", branch="dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] }
+sppark = { git = "https://github.com/lurk-lab/sppark.git", branch = "preallocated-msm" }
+pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] }
+paste = "1.0.14"
 
 [build-dependencies]
 cc = "^1.0.70"
11 changes: 11 additions & 0 deletions benches/main.rs
@@ -66,6 +66,17 @@ fn criterion_benchmark(c: &mut Criterion) {
             })
         });
 
+        let context = pasta_msm::pallas_init(&points, npoints);
+
+        group.bench_function(
+            format!("preallocated 2**{} points", bench_npow),
+            |b| {
+                b.iter(|| {
+                    let _ = pasta_msm::pallas_with(&context, npoints, &scalars);
+                })
+            },
+        );
+
         group.finish();
     }
 }
34 changes: 30 additions & 4 deletions cuda/pallas.cu
@@ -17,8 +17,34 @@ typedef vesta_t scalar_t;
 #include <msm/pippenger.cuh>
 
 #ifndef __CUDA_ARCH__
-extern "C"
-RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints,
-                                const scalar_t scalars[])
-{ return mult_pippenger<bucket_t>(out, points, npoints, scalars); }
+
+struct msm_context_pallas_t
+{
+    affine_t::mem_t *d_points;
+};
+
+extern "C" void drop_msm_context_t(msm_context_t<void> &ref)
+{
+    CUDA_OK(cudaFree(ref.d_points));
+}
+
+extern "C" RustError
+cuda_pippenger_pallas_init(const affine_t points[], size_t npoints, msm_context_pallas_t *msm_context)
+{
+    msm_t<bucket_t, point_t, affine_t, scalar_t> msm{points, npoints, false};
+    msm_context->d_points = msm.get_d_points();
+    return RustError{cudaSuccess};
+}
+
+extern "C" RustError cuda_pippenger_pallas(point_t *out, const affine_t points[], size_t npoints,
+                                           const scalar_t scalars[])
+{
+    return mult_pippenger<bucket_t>(out, points, npoints, scalars);
+}
+
+extern "C" RustError cuda_pippenger_pallas_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
+                                                const scalar_t scalars[])
+{
+    return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context->d_points, npoints, scalars);
+}
 #endif
34 changes: 30 additions & 4 deletions cuda/vesta.cu
@@ -17,8 +17,34 @@ typedef pallas_t scalar_t;
 #include <msm/pippenger.cuh>
 
 #ifndef __CUDA_ARCH__
-extern "C"
-RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints,
-                               const scalar_t scalars[])
-{ return mult_pippenger<bucket_t>(out, points, npoints, scalars); }
+
+struct msm_context_vesta_t
+{
+    affine_t::mem_t *d_points;
+};
+
+extern "C" void drop_msm_context_t(msm_context_t<void> &ref)
+{
+    CUDA_OK(cudaFree(ref.d_points));
+}
+
+extern "C" RustError
+cuda_pippenger_vesta_init(const affine_t points[], size_t npoints, msm_context_vesta_t *msm_context)
+{
+    msm_t<bucket_t, point_t, affine_t, scalar_t> msm{points, npoints, false};
+    msm_context->d_points = msm.get_d_points();
+    return RustError{cudaSuccess};
+}
+
+extern "C" RustError cuda_pippenger_vesta(point_t *out, const affine_t points[], size_t npoints,
+                                          const scalar_t scalars[])
+{
+    return mult_pippenger<bucket_t>(out, points, npoints, scalars);
+}
+
+extern "C" RustError cuda_pippenger_vesta_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
+                                               const scalar_t scalars[])
+{
+    return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context->d_points, npoints, scalars);
+}
 #endif
112 changes: 112 additions & 0 deletions src/lib.rs
@@ -4,6 +4,45 @@
 
 extern crate semolina;
 
+#[cfg(feature = "cuda")]
+#[repr(C)]
+#[derive(Debug)]
+pub struct MSMContext {
+    context: *const std::ffi::c_void,
+}
+
+#[cfg(feature = "cuda")]
+unsafe impl Send for MSMContext {}
+
+#[cfg(feature = "cuda")]
+unsafe impl Sync for MSMContext {}
+
+#[cfg(feature = "cuda")]
+impl Default for MSMContext {
+    fn default() -> Self {
+        Self { context: std::ptr::null() }
+    }
+}
+
+#[cfg(feature = "cuda")]
+impl Clone for MSMContext {
+    fn clone(&self) -> Self {
+        Self { context: self.context.clone() }
+    }
+}
+
+#[cfg(feature = "cuda")]
+// TODO: check for device-side memory leaks
+impl Drop for MSMContext {
+    fn drop(&mut self) {
+        extern "C" {
+            fn drop_msm_context_t(by_ref: &MSMContext);
+        }
+        unsafe { drop_msm_context_t(std::mem::transmute::<&_, &_>(self)) };
+        self.context = core::ptr::null();
+    }
+}
+
 #[cfg(feature = "cuda")]
 sppark::cuda_error!();
 #[cfg(feature = "cuda")]
@@ -31,6 +70,79 @@ macro_rules! multi_scalar_mult {
             );
         }
 
+        paste::paste! {
+            #[cfg(feature = "cuda")]
+            pub fn [<$pasta _init>](
+                points: &[$pasta::Affine],
+                npoints: usize,
+            ) -> MSMContext {
+                unsafe { assert!(!CUDA_OFF && cuda_available(), "feature = \"cuda\" must be enabled") };
+                if npoints != points.len() && npoints < 1 << 16 {
+                    panic!("length mismatch or less than 10**16")
+                }
+                extern "C" {
+                    fn [<cuda_pippenger_ $pasta _init>](
+                        points: *const pallas::Affine,
+                        npoints: usize,
+                        d_points: &mut MSMContext,
+                        is_mont: bool,
+                    ) -> cuda::Error;
+
+                }
+
+                let mut ret = MSMContext::default();
+                let err = unsafe {
+                    [<cuda_pippenger_ $pasta _init>](
+                        points.as_ptr() as *const _,
+                        npoints,
+                        &mut ret,
+                        true,
+                    )
+                };
+                if err.code != 0 {
+                    panic!("{}", String::from(err));
+                }
+                ret
+            }
+
+            #[cfg(feature = "cuda")]
+            pub fn [<$pasta _with>](
+                context: &MSMContext,
+                npoints: usize,
+                scalars: &[$pasta::Scalar],
+            ) -> $pasta::Point {
+                unsafe { assert!(!CUDA_OFF && cuda_available(), "feature = \"cuda\" must be enabled") };
+                if npoints != scalars.len() && npoints < 1 << 16 {
+                    panic!("length mismatch or less than 10**16")
+                }
+                extern "C" {
+                    fn [<cuda_pippenger_ $pasta _with>](
+                        out: *mut $pasta::Point,
+                        context: &MSMContext,
+                        npoints: usize,
+                        scalars: *const $pasta::Scalar,
+                        is_mont: bool,
+                    ) -> cuda::Error;
+
+                }
+
+                let mut ret = $pasta::Point::default();
+                let err = unsafe {
+                    [<cuda_pippenger_ $pasta _with>](
+                        &mut ret,
+                        context,
+                        npoints,
+                        &scalars[0],
+                        true,
+                    )
+                };
+                if err.code != 0 {
+                    panic!("{}", String::from(err));
+                }
+                ret
+            }
+        }
+
         pub fn $pasta(
             points: &[$pasta::Affine],
             scalars: &[$pasta::Scalar],
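For context, a minimal sketch of how the preallocated-points API introduced in this commit would be used, mirroring the benchmark above: `pallas_init` uploads the bases to the device once and returns an `MSMContext`, and each later `pallas_with` call reuses the device-resident points (the guards in the new functions suggest point sets of at least 2**16). The `msm_over_fixed_bases` helper and its batch layout are illustrative assumptions, not part of the commit.

// Sketch only: pasta_msm::pallas_init / pallas_with come from this commit;
// the helper below and its batching scheme are hypothetical.
use pasta_curves::pallas;

/// Runs several MSMs over the same fixed set of bases, uploading the
/// points to the GPU only once.
pub fn msm_over_fixed_bases(
    points: &[pallas::Affine],
    scalar_batches: &[Vec<pallas::Scalar>],
) -> Vec<pallas::Point> {
    let npoints = points.len();
    // One-time device transfer; the returned context holds the device copy
    // of the points and frees it when dropped.
    let context = pasta_msm::pallas_init(points, npoints);
    scalar_batches
        .iter()
        .map(|scalars| pasta_msm::pallas_with(&context, npoints, scalars))
        .collect()
}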
