From 41f0573eae7fed7ede43d2618ff657f05a50dc38 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Sun, 3 Dec 2023 01:35:24 +0000 Subject: [PATCH] preallocate memory --- .vscode/settings.json | 4 +- benches/spmvm.rs | 9 +- cuda/pallas.cu | 17 +++ cuda/vesta.cu | 17 ++- examples/spmvm_pallas.rs | 9 +- examples/spmvm_vesta.rs | 59 ++++----- src/spmvm.rs | 252 --------------------------------------- src/spmvm/mod.rs | 76 ++++++++++++ src/spmvm/pallas.rs | 217 +++++++++++++++++++++++++++++++++ src/spmvm/vesta.rs | 216 +++++++++++++++++++++++++++++++++ 10 files changed, 584 insertions(+), 292 deletions(-) delete mode 100644 src/spmvm.rs create mode 100644 src/spmvm/mod.rs create mode 100644 src/spmvm/pallas.rs create mode 100644 src/spmvm/vesta.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index 0c1c0d0..d16c14d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,8 @@ { "files.associations": { "__locale": "cpp", - "ios": "cpp" + "ios": "cpp", + "functional": "cpp", + "__functional_base": "cpp" } } \ No newline at end of file diff --git a/benches/spmvm.rs b/benches/spmvm.rs index 28f3683..dec0a8f 100644 --- a/benches/spmvm.rs +++ b/benches/spmvm.rs @@ -59,13 +59,14 @@ fn criterion_benchmark(c: &mut Criterion) { .unwrap_or("17".to_string()) .parse() .unwrap(); - let n: usize = 1 << bench_npow; + let n = 1usize << (bench_npow + 1); + let m = 1usize << bench_npow; println!("generating random matrix and scalars, just hang on..."); - let csr = generate_csr(n, n); + let csr = generate_csr(n, m); let cuda_csr = - CudaSparseMatrix::new(&csr.data, &csr.indices, &csr.indptr, n, n); - let W = crate::tests::gen_scalars(n - 10); + CudaSparseMatrix::new(&csr.data, &csr.indices, &csr.indptr, n, m); + let W = crate::tests::gen_scalars(m - 10); let U = crate::tests::gen_scalars(9); let witness = CudaWitness::new(&W, &pallas::Scalar::ONE, &U); let scalars = [W.clone(), vec![pallas::Scalar::ONE], U.clone()].concat(); diff --git a/cuda/pallas.cu b/cuda/pallas.cu index 2dd1a08..5fa27e1 100644 --- a/cuda/pallas.cu +++ b/cuda/pallas.cu @@ -25,17 +25,34 @@ extern "C" RustError cuda_double_pallas(double_host_t *csr, scalar_t * return double_scalars(csr, scalars, out); } +extern "C" void drop_spmvm_context_pallas(spmvm_context_t &ref) +{ + drop_spmvm_context(ref); +} + extern "C" RustError cuda_sparse_matrix_mul_pallas(spmvm_host_t *csr, const scalar_t *scalars, scalar_t *out, size_t nthreads) { return sparse_matrix_mul(csr, scalars, out, nthreads); } +extern "C" RustError cuda_sparse_matrix_witness_init_pallas( + spmvm_host_t *csr, spmvm_context_t *context) +{ + return sparse_matrix_witness_init(csr, context); +} + extern "C" RustError cuda_sparse_matrix_witness_pallas( spmvm_host_t *csr, const witness_t *witness, scalar_t *out, size_t nthreads) { return sparse_matrix_witness(csr, witness, out, nthreads); } +extern "C" RustError cuda_sparse_matrix_witness_with_pallas( + spmvm_context_t *context, const witness_t *witness, scalar_t *out, size_t nthreads) +{ + return sparse_matrix_witness_with(context, witness, out, nthreads); +} + extern "C" RustError cuda_sparse_matrix_witness_pallas_cpu( spmvm_host_t *csr, const witness_t *witness, scalar_t *out) { diff --git a/cuda/vesta.cu b/cuda/vesta.cu index 874c4f0..987e237 100644 --- a/cuda/vesta.cu +++ b/cuda/vesta.cu @@ -19,8 +19,9 @@ typedef pallas_t scalar_t; #ifndef __CUDA_ARCH__ -extern "C" void drop_msm_context_vesta(msm_context_t &ref) { - CUDA_OK(cudaFree(ref.d_points)); +extern "C" void drop_spmvm_context_vesta(spmvm_context_t &ref) 
+{ + drop_spmvm_context(ref); } extern "C" RustError cuda_sparse_matrix_mul_vesta(spmvm_host_t *csr, const scalar_t *scalars, scalar_t *out, size_t nthreads) @@ -28,12 +29,24 @@ extern "C" RustError cuda_sparse_matrix_mul_vesta(spmvm_host_t *csr, c return sparse_matrix_mul(csr, scalars, out, nthreads); } +extern "C" RustError cuda_sparse_matrix_witness_init_vesta( + spmvm_host_t *csr, spmvm_context_t *context) +{ + return sparse_matrix_witness_init(csr, context); +} + extern "C" RustError cuda_sparse_matrix_witness_vesta( spmvm_host_t *csr, const witness_t *witness, scalar_t *out, size_t nthreads) { return sparse_matrix_witness(csr, witness, out, nthreads); } +extern "C" RustError cuda_sparse_matrix_witness_with_vesta( + spmvm_context_t *context, const witness_t *witness, scalar_t *out, size_t nthreads) +{ + return sparse_matrix_witness_with(context, witness, out, nthreads); +} + extern "C" RustError cuda_sparse_matrix_witness_vesta_cpu( spmvm_host_t *csr, const witness_t *witness, scalar_t *out) { diff --git a/examples/spmvm_pallas.rs b/examples/spmvm_pallas.rs index 6258a43..78b1dbe 100644 --- a/examples/spmvm_pallas.rs +++ b/examples/spmvm_pallas.rs @@ -7,7 +7,7 @@ use std::time::Instant; use pasta_curves::{group::ff::{PrimeField, Field}, pallas}; use pasta_msm::{ - spmvm::{sparse_matrix_mul_pallas, CudaSparseMatrix, CudaWitness, sparse_matrix_witness_pallas}, + spmvm::{CudaSparseMatrix, CudaWitness, pallas::{sparse_matrix_witness_with_pallas, sparse_matrix_witness_init_pallas}}, utils::SparseMatrix, }; use rand::Rng; @@ -53,12 +53,12 @@ pub fn generate_scalars(len: usize) -> Vec { /// cargo run --release --example spmvm fn main() { let npow: usize = std::env::var("NPOW") - .unwrap_or("17".to_string()) + .unwrap_or("20".to_string()) .parse() .unwrap(); let n = 1usize << npow; let nthreads: usize = std::env::var("NTHREADS") - .unwrap_or("128".to_string()) + .unwrap_or("256".to_string()) .parse() .unwrap(); @@ -73,10 +73,11 @@ fn main() { let res = csr.multiply_vec(&scalars); println!("cpu took: {:?}", start.elapsed()); + let spmvm_context = sparse_matrix_witness_init_pallas(&cuda_csr); let witness = CudaWitness::new(&W, &pallas::Scalar::ONE, &U); let mut cuda_res = vec![pallas::Scalar::ONE; cuda_csr.num_rows]; let start = Instant::now(); - sparse_matrix_witness_pallas(&cuda_csr, &witness, &mut cuda_res, nthreads); + sparse_matrix_witness_with_pallas(&spmvm_context, &witness, &mut cuda_res, nthreads); println!("gpu took: {:?}", start.elapsed()); assert_eq!(res, cuda_res); diff --git a/examples/spmvm_vesta.rs b/examples/spmvm_vesta.rs index 4b7ae0f..c9d22a5 100644 --- a/examples/spmvm_vesta.rs +++ b/examples/spmvm_vesta.rs @@ -11,8 +11,10 @@ use pasta_curves::{ }; use pasta_msm::{ spmvm::{ - sparse_matrix_mul_vesta, sparse_matrix_witness_vesta, CudaSparseMatrix, - CudaWitness, sparse_matrix_witness_vesta_cpu, + vesta::{ + sparse_matrix_witness_init_vesta, sparse_matrix_witness_with_vesta, sparse_matrix_witness_vesta, + }, + CudaSparseMatrix, CudaWitness, }, utils::SparseMatrix, }; @@ -59,43 +61,42 @@ pub fn generate_scalars(len: usize) -> Vec { /// cargo run --release --example spmvm fn main() { let npow: usize = std::env::var("NPOW") - .unwrap_or("3".to_string()) + .unwrap_or("20".to_string()) .parse() .unwrap(); - let n = 1usize << npow; + let n = 1usize << (npow + 1); let nthreads: usize = std::env::var("NTHREADS") - .unwrap_or("1".to_string()) + .unwrap_or("128".to_string()) .parse() .unwrap(); - let csr_A = generate_csr(n, n); - let cuda_csr_A = - CudaSparseMatrix::new(&csr_A.data, 
&csr_A.indices, &csr_A.indptr, n, n); - let csr_B = generate_csr(n, n); - let cuda_csr_B = - CudaSparseMatrix::new(&csr_B.data, &csr_B.indices, &csr_B.indptr, n, n); - - // let W = generate_scalars(n - 10); - // let U = generate_scalars(9); - // let scalars = [W.clone(), vec![vesta::Scalar::ONE], U.clone()].concat(); - let W = vec![vesta::Scalar::ZERO; n - 3]; - let U = vec![vesta::Scalar::ZERO; 2]; - let scalars = vec![vesta::Scalar::ZERO; n]; + let csr = generate_csr(n, n); + let cuda_csr = + CudaSparseMatrix::new(&csr.data, &csr.indices, &csr.indptr, n, n); + let W = generate_scalars(n - 10); + let U = generate_scalars(9); + let scalars = [W.clone(), vec![vesta::Scalar::ONE], U.clone()].concat(); let start = Instant::now(); - let res_A = csr_A.multiply_vec(&scalars); - let res_B = csr_B.multiply_vec(&scalars); - println!("native took: {:?}", start.elapsed()); + let res = csr.multiply_vec(&scalars); + println!("cpu took: {:?}", start.elapsed()); - let witness = CudaWitness::new(&W, &vesta::Scalar::ZERO, &U); - let mut cuda_res_A = vec![vesta::Scalar::ZERO; cuda_csr_A.num_rows]; - let mut cuda_res_B = vec![vesta::Scalar::ZERO; cuda_csr_B.num_rows]; + let witness = CudaWitness::new(&W, &vesta::Scalar::ONE, &U); + let mut cuda_res = vec![vesta::Scalar::ONE; cuda_csr.num_rows]; let start = Instant::now(); - sparse_matrix_witness_vesta(&cuda_csr_A, &witness, &mut cuda_res_A, nthreads); - sparse_matrix_witness_vesta(&cuda_csr_B, &witness, &mut cuda_res_B, nthreads); - println!("ffi took: {:?}", start.elapsed()); + sparse_matrix_witness_vesta(&cuda_csr, &witness, &mut cuda_res, nthreads); + println!("gpu took: {:?}", start.elapsed()); - assert_eq!(res_A, cuda_res_A); - assert!(res_B == cuda_res_B); + let spmvm_context = sparse_matrix_witness_init_vesta(&cuda_csr); + let start = Instant::now(); + sparse_matrix_witness_with_vesta( + &spmvm_context, + &witness, + &mut cuda_res, + nthreads, + ); + println!("preallocated gpu took: {:?}", start.elapsed()); + + assert_eq!(res, cuda_res); println!("success!"); } diff --git a/src/spmvm.rs b/src/spmvm.rs deleted file mode 100644 index accd763..0000000 --- a/src/spmvm.rs +++ /dev/null @@ -1,252 +0,0 @@ -#![allow(non_snake_case)] - -use std::marker::PhantomData; -use pasta_curves::{group::ff::Field, pallas, vesta}; - -#[repr(C)] -pub struct CudaSparseMatrix<'a, F> { - pub data: *const F, - pub col_idx: *const usize, - pub row_ptr: *const usize, - - pub num_rows: usize, - pub num_cols: usize, - pub nnz: usize, - - _p: PhantomData<&'a F>, -} - -impl<'a, F> CudaSparseMatrix<'a, F> { - pub fn new( - data: &[F], - col_idx: &[usize], - row_ptr: &[usize], - num_rows: usize, - num_cols: usize, - ) -> Self { - assert_eq!( - data.len(), - col_idx.len(), - "data and col_idx length mismatch" - ); - assert_eq!( - row_ptr.len(), - num_rows + 1, - "row_ptr length and num_rows mismatch" - ); - - let nnz = data.len(); - CudaSparseMatrix { - data: data.as_ptr(), - col_idx: col_idx.as_ptr(), - row_ptr: row_ptr.as_ptr(), - num_rows, - num_cols, - nnz, - _p: PhantomData, - } - } -} - -#[repr(C)] -pub struct CudaWitness<'a, F> { - pub W: *const F, - pub u: *const F, - pub U: *const F, - pub nW: usize, - pub nU: usize, - _p: PhantomData<&'a F>, -} - -impl<'a, F> CudaWitness<'a, F> { - pub fn new( - W: &[F], - u: &F, - U: &[F], - ) -> Self { - let nW = W.len(); - let nU = U.len(); - CudaWitness { - W: W.as_ptr(), - u: u as *const _, - U: U.as_ptr(), - nW, - nU, - _p: PhantomData, - } - } -} - -pub fn sparse_matrix_mul_pallas( - csr: &CudaSparseMatrix, - scalars: 
&[pallas::Scalar], - nthreads: usize, -) -> Vec { - extern "C" { - fn cuda_sparse_matrix_mul_pallas( - csr: *const CudaSparseMatrix, - scalars: *const pallas::Scalar, - out: *mut pallas::Scalar, - nthreads: usize, - ) -> sppark::Error; - } - - let mut out = vec![pallas::Scalar::ZERO; csr.num_rows]; - let err = unsafe { - cuda_sparse_matrix_mul_pallas( - csr as *const _, - scalars.as_ptr(), - out.as_mut_ptr(), - nthreads, - ) - }; - if err.code != 0 { - panic!("{}", String::from(err)); - } - - out -} - -pub fn sparse_matrix_witness_pallas( - csr: &CudaSparseMatrix, - witness: &CudaWitness, - buffer: &mut [pallas::Scalar], - nthreads: usize, -) { - extern "C" { - fn cuda_sparse_matrix_witness_pallas( - csr: *const CudaSparseMatrix, - witness: *const CudaWitness, - out: *mut pallas::Scalar, - nthreads: usize, - ) -> sppark::Error; - } - - assert_eq!(witness.nW + witness.nU + 1, csr.num_cols, "invalid witness size"); - - let err = unsafe { - cuda_sparse_matrix_witness_pallas( - csr as *const _, - witness as *const _, - buffer.as_mut_ptr(), - nthreads, - ) - }; - if err.code != 0 { - panic!("{}", String::from(err)); - } -} - -pub fn sparse_matrix_witness_pallas_cpu( - csr: &CudaSparseMatrix, - witness: &CudaWitness, - buffer: &mut [pallas::Scalar], -) { - extern "C" { - fn cuda_sparse_matrix_witness_pallas_cpu( - csr: *const CudaSparseMatrix, - witness: *const CudaWitness, - out: *mut pallas::Scalar, - ) -> sppark::Error; - } - - assert_eq!(witness.nW + witness.nU + 1, csr.num_cols, "invalid witness size"); - - let err = unsafe { - cuda_sparse_matrix_witness_pallas_cpu( - csr as *const _, - witness as *const _, - buffer.as_mut_ptr(), - ) - }; - if err.code != 0 { - panic!("{}", String::from(err)); - } -} - -pub fn sparse_matrix_mul_vesta( - csr: &CudaSparseMatrix, - scalars: &[vesta::Scalar], - nthreads: usize, -) -> Vec { - extern "C" { - fn cuda_sparse_matrix_mul_vesta( - csr: *const CudaSparseMatrix, - scalars: *const vesta::Scalar, - out: *mut vesta::Scalar, - nthreads: usize, - ) -> sppark::Error; - } - - let mut out = vec![vesta::Scalar::ZERO; csr.num_rows]; - let err = unsafe { - cuda_sparse_matrix_mul_vesta( - csr as *const _, - scalars.as_ptr(), - out.as_mut_ptr(), - nthreads, - ) - }; - if err.code != 0 { - panic!("{}", String::from(err)); - } - - out -} - -pub fn sparse_matrix_witness_vesta( - csr: &CudaSparseMatrix, - witness: &CudaWitness, - buffer: &mut [vesta::Scalar], - nthreads: usize, -) { - extern "C" { - fn cuda_sparse_matrix_witness_vesta( - csr: *const CudaSparseMatrix, - witness: *const CudaWitness, - out: *mut vesta::Scalar, - nthreads: usize, - ) -> sppark::Error; - } - - assert_eq!(witness.nW + witness.nU + 1, csr.num_cols, "invalid witness size"); - - let err = unsafe { - cuda_sparse_matrix_witness_vesta( - csr as *const _, - witness as *const _, - buffer.as_mut_ptr(), - nthreads, - ) - }; - if err.code != 0 { - panic!("{}", String::from(err)); - } -} - -pub fn sparse_matrix_witness_vesta_cpu( - csr: &CudaSparseMatrix, - witness: &CudaWitness, - buffer: &mut [vesta::Scalar], -) { - extern "C" { - fn cuda_sparse_matrix_witness_vesta_cpu( - csr: *const CudaSparseMatrix, - witness: *const CudaWitness, - out: *mut vesta::Scalar, - ) -> sppark::Error; - } - - assert_eq!(witness.nW + witness.nU + 1, csr.num_cols, "invalid witness size"); - - let err = unsafe { - cuda_sparse_matrix_witness_vesta_cpu( - csr as *const _, - witness as *const _, - buffer.as_mut_ptr(), - ) - }; - if err.code != 0 { - panic!("{}", String::from(err)); - } -} diff --git a/src/spmvm/mod.rs 
b/src/spmvm/mod.rs new file mode 100644 index 0000000..2461dfa --- /dev/null +++ b/src/spmvm/mod.rs @@ -0,0 +1,76 @@ +#![allow(non_snake_case)] + +use std::marker::PhantomData; + +pub mod pallas; +pub mod vesta; + +#[repr(C)] +pub struct CudaSparseMatrix<'a, F> { + pub data: *const F, + pub col_idx: *const usize, + pub row_ptr: *const usize, + + pub num_rows: usize, + pub num_cols: usize, + pub nnz: usize, + + _p: PhantomData<&'a F>, +} + +impl<'a, F> CudaSparseMatrix<'a, F> { + pub fn new( + data: &[F], + col_idx: &[usize], + row_ptr: &[usize], + num_rows: usize, + num_cols: usize, + ) -> Self { + assert_eq!( + data.len(), + col_idx.len(), + "data and col_idx length mismatch" + ); + assert_eq!( + row_ptr.len(), + num_rows + 1, + "row_ptr length and num_rows mismatch" + ); + + let nnz = data.len(); + CudaSparseMatrix { + data: data.as_ptr(), + col_idx: col_idx.as_ptr(), + row_ptr: row_ptr.as_ptr(), + num_rows, + num_cols, + nnz, + _p: PhantomData, + } + } +} + +#[repr(C)] +pub struct CudaWitness<'a, F> { + pub W: *const F, + pub u: *const F, + pub U: *const F, + pub nW: usize, + pub nU: usize, + _p: PhantomData<&'a F>, +} + +impl<'a, F> CudaWitness<'a, F> { + pub fn new(W: &[F], u: &F, U: &[F]) -> Self { + let nW = W.len(); + let nU = U.len(); + CudaWitness { + W: W.as_ptr(), + u: u as *const _, + U: U.as_ptr(), + nW, + nU, + _p: PhantomData, + } + } +} \ No newline at end of file diff --git a/src/spmvm/pallas.rs b/src/spmvm/pallas.rs new file mode 100644 index 0000000..899fd97 --- /dev/null +++ b/src/spmvm/pallas.rs @@ -0,0 +1,217 @@ +#![allow(non_snake_case)] + +use std::ffi::c_void; + +use pasta_curves::{pallas, group::ff::Field}; + +use super::{CudaSparseMatrix, CudaWitness}; + + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct SpMVMContextPallas { + pub d_data: *const c_void, + pub d_col_idx: *const c_void, + pub d_row_ptr: *const c_void, + + pub num_rows: usize, + pub num_cols: usize, + pub nnz: usize, + + pub d_scalars: *const c_void, + pub d_out: *const c_void, +} + +unsafe impl Send for SpMVMContextPallas {} +unsafe impl Sync for SpMVMContextPallas {} + +impl Default for SpMVMContextPallas { + fn default() -> Self { + Self { + d_data: core::ptr::null(), + d_col_idx: core::ptr::null(), + d_row_ptr: core::ptr::null(), + num_rows: 0, + num_cols: 0, + nnz: 0, + d_scalars: core::ptr::null(), + d_out: core::ptr::null(), + } + } +} + +// TODO: check for device-side memory leaks +impl Drop for SpMVMContextPallas { + fn drop(&mut self) { + extern "C" { + fn drop_spmvm_context_pallas(by_ref: &SpMVMContextPallas); + } + unsafe { + drop_spmvm_context_pallas(std::mem::transmute::<&_, &_>(self)) + }; + + self.d_data = core::ptr::null(); + self.d_col_idx = core::ptr::null(); + self.d_row_ptr = core::ptr::null(); + + self.num_rows = 0; + self.num_cols = 0; + self.nnz = 0; + + self.d_scalars = core::ptr::null(); + self.d_out = core::ptr::null(); + } +} + +pub fn sparse_matrix_mul_pallas( + csr: &CudaSparseMatrix, + scalars: &[pallas::Scalar], + nthreads: usize, +) -> Vec { + extern "C" { + fn cuda_sparse_matrix_mul_pallas( + csr: *const CudaSparseMatrix, + scalars: *const pallas::Scalar, + out: *mut pallas::Scalar, + nthreads: usize, + ) -> sppark::Error; + } + + let mut out = vec![pallas::Scalar::ZERO; csr.num_rows]; + let err = unsafe { + cuda_sparse_matrix_mul_pallas( + csr as *const _, + scalars.as_ptr(), + out.as_mut_ptr(), + nthreads, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } + + out +} + +pub fn sparse_matrix_witness_init_pallas( + csr: &CudaSparseMatrix, +) -> 
SpMVMContextPallas { + extern "C" { + fn cuda_sparse_matrix_witness_init_pallas( + csr: *const CudaSparseMatrix, + context: *mut SpMVMContextPallas, + ) -> sppark::Error; + } + + let mut context = SpMVMContextPallas::default(); + let err = unsafe { + cuda_sparse_matrix_witness_init_pallas( + csr as *const _, + &mut context as *mut _, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } + + context +} + +pub fn sparse_matrix_witness_pallas( + csr: &CudaSparseMatrix, + witness: &CudaWitness, + buffer: &mut [pallas::Scalar], + nthreads: usize, +) { + extern "C" { + fn cuda_sparse_matrix_witness_pallas( + csr: *const CudaSparseMatrix, + witness: *const CudaWitness, + out: *mut pallas::Scalar, + nthreads: usize, + ) -> sppark::Error; + } + + assert_eq!( + witness.nW + witness.nU + 1, + csr.num_cols, + "invalid witness size" + ); + + let err = unsafe { + cuda_sparse_matrix_witness_pallas( + csr as *const _, + witness as *const _, + buffer.as_mut_ptr(), + nthreads, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } +} + +pub fn sparse_matrix_witness_with_pallas( + context: &SpMVMContextPallas, + witness: &CudaWitness, + buffer: &mut [pallas::Scalar], + nthreads: usize, +) { + extern "C" { + fn cuda_sparse_matrix_witness_with_pallas( + context: *const SpMVMContextPallas, + witness: *const CudaWitness, + out: *mut pallas::Scalar, + nthreads: usize, + ) -> sppark::Error; + } + + assert_eq!( + witness.nW + witness.nU + 1, + context.num_cols, + "invalid witness size" + ); + + let err = unsafe { + cuda_sparse_matrix_witness_with_pallas( + context as *const _, + witness as *const _, + buffer.as_mut_ptr(), + nthreads, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } +} + +pub fn sparse_matrix_witness_pallas_cpu( + csr: &CudaSparseMatrix, + witness: &CudaWitness, + buffer: &mut [pallas::Scalar], +) { + extern "C" { + fn cuda_sparse_matrix_witness_pallas_cpu( + csr: *const CudaSparseMatrix, + witness: *const CudaWitness, + out: *mut pallas::Scalar, + ) -> sppark::Error; + } + + assert_eq!( + witness.nW + witness.nU + 1, + csr.num_cols, + "invalid witness size" + ); + + let err = unsafe { + cuda_sparse_matrix_witness_pallas_cpu( + csr as *const _, + witness as *const _, + buffer.as_mut_ptr(), + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } +} \ No newline at end of file diff --git a/src/spmvm/vesta.rs b/src/spmvm/vesta.rs new file mode 100644 index 0000000..be47db9 --- /dev/null +++ b/src/spmvm/vesta.rs @@ -0,0 +1,216 @@ +#![allow(non_snake_case)] + +use std::ffi::c_void; + +use pasta_curves::{vesta, group::ff::Field}; + +use super::{CudaSparseMatrix, CudaWitness}; + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct SpMVMContextVesta { + pub d_data: *const c_void, + pub d_col_idx: *const c_void, + pub d_row_ptr: *const c_void, + + pub num_rows: usize, + pub num_cols: usize, + pub nnz: usize, + + pub d_scalars: *const c_void, + pub d_out: *const c_void, +} + +unsafe impl Send for SpMVMContextVesta {} +unsafe impl Sync for SpMVMContextVesta {} + +impl Default for SpMVMContextVesta { + fn default() -> Self { + Self { + d_data: core::ptr::null(), + d_col_idx: core::ptr::null(), + d_row_ptr: core::ptr::null(), + num_rows: 0, + num_cols: 0, + nnz: 0, + d_scalars: core::ptr::null(), + d_out: core::ptr::null(), + } + } +} + +// TODO: check for device-side memory leaks +impl Drop for SpMVMContextVesta { + fn drop(&mut self) { + extern "C" { + fn drop_spmvm_context_vesta(by_ref: &SpMVMContextVesta); + } + unsafe { + 
drop_spmvm_context_vesta(std::mem::transmute::<&_, &_>(self)) + }; + + self.d_data = core::ptr::null(); + self.d_col_idx = core::ptr::null(); + self.d_row_ptr = core::ptr::null(); + + self.num_rows = 0; + self.num_cols = 0; + self.nnz = 0; + + self.d_scalars = core::ptr::null(); + self.d_out = core::ptr::null(); + } +} + +pub fn sparse_matrix_mul_vesta( + csr: &CudaSparseMatrix, + scalars: &[vesta::Scalar], + nthreads: usize, +) -> Vec { + extern "C" { + fn cuda_sparse_matrix_mul_vesta( + csr: *const CudaSparseMatrix, + scalars: *const vesta::Scalar, + out: *mut vesta::Scalar, + nthreads: usize, + ) -> sppark::Error; + } + + let mut out = vec![vesta::Scalar::ZERO; csr.num_rows]; + let err = unsafe { + cuda_sparse_matrix_mul_vesta( + csr as *const _, + scalars.as_ptr(), + out.as_mut_ptr(), + nthreads, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } + + out +} + +pub fn sparse_matrix_witness_init_vesta( + csr: &CudaSparseMatrix, +) -> SpMVMContextVesta { + extern "C" { + fn cuda_sparse_matrix_witness_init_vesta( + csr: *const CudaSparseMatrix, + context: *mut SpMVMContextVesta, + ) -> sppark::Error; + } + + let mut context = SpMVMContextVesta::default(); + let err = unsafe { + cuda_sparse_matrix_witness_init_vesta( + csr as *const _, + &mut context as *mut _, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } + + context +} + +pub fn sparse_matrix_witness_vesta( + csr: &CudaSparseMatrix, + witness: &CudaWitness, + buffer: &mut [vesta::Scalar], + nthreads: usize, +) { + extern "C" { + fn cuda_sparse_matrix_witness_vesta( + csr: *const CudaSparseMatrix, + witness: *const CudaWitness, + out: *mut vesta::Scalar, + nthreads: usize, + ) -> sppark::Error; + } + + assert_eq!( + witness.nW + witness.nU + 1, + csr.num_cols, + "invalid witness size" + ); + + let err = unsafe { + cuda_sparse_matrix_witness_vesta( + csr as *const _, + witness as *const _, + buffer.as_mut_ptr(), + nthreads, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } +} + +pub fn sparse_matrix_witness_with_vesta( + context: &SpMVMContextVesta, + witness: &CudaWitness, + buffer: &mut [vesta::Scalar], + nthreads: usize, +) { + extern "C" { + fn cuda_sparse_matrix_witness_with_vesta( + context: *const SpMVMContextVesta, + witness: *const CudaWitness, + out: *mut vesta::Scalar, + nthreads: usize, + ) -> sppark::Error; + } + + assert_eq!( + witness.nW + witness.nU + 1, + context.num_cols, + "invalid witness size" + ); + + let err = unsafe { + cuda_sparse_matrix_witness_with_vesta( + context as *const _, + witness as *const _, + buffer.as_mut_ptr(), + nthreads, + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } +} + +pub fn sparse_matrix_witness_vesta_cpu( + csr: &CudaSparseMatrix, + witness: &CudaWitness, + buffer: &mut [vesta::Scalar], +) { + extern "C" { + fn cuda_sparse_matrix_witness_vesta_cpu( + csr: *const CudaSparseMatrix, + witness: *const CudaWitness, + out: *mut vesta::Scalar, + ) -> sppark::Error; + } + + assert_eq!( + witness.nW + witness.nU + 1, + csr.num_cols, + "invalid witness size" + ); + + let err = unsafe { + cuda_sparse_matrix_witness_vesta_cpu( + csr as *const _, + witness as *const _, + buffer.as_mut_ptr(), + ) + }; + if err.code != 0 { + panic!("{}", String::from(err)); + } +} \ No newline at end of file
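
Not part of the patch itself: below is a minimal sketch of how the preallocated path added by this commit is intended to be used, following examples/spmvm_pallas.rs. The driver function `run`, its parameter list, the batching over `(W, U)` pairs, and the choice of u = 1 are illustrative assumptions; only the `pasta_msm::spmvm` types and the `sparse_matrix_witness_init_pallas` / `sparse_matrix_witness_with_pallas` entry points come from this commit.

use pasta_curves::{group::ff::Field, pallas};
use pasta_msm::spmvm::{
    pallas::{sparse_matrix_witness_init_pallas, sparse_matrix_witness_with_pallas},
    CudaSparseMatrix, CudaWitness,
};

#[allow(non_snake_case)]
fn run(
    data: &[pallas::Scalar],
    col_idx: &[usize],
    row_ptr: &[usize],
    num_rows: usize,
    num_cols: usize,
    witnesses: &[(Vec<pallas::Scalar>, Vec<pallas::Scalar>)], // (W, U) pairs
    nthreads: usize,
) -> Vec<Vec<pallas::Scalar>> {
    // Wrap the host CSR; CudaSparseMatrix::new asserts that row_ptr has
    // num_rows + 1 entries and that data and col_idx have matching lengths.
    let csr = CudaSparseMatrix::new(data, col_idx, row_ptr, num_rows, num_cols);

    // Upload the matrix and allocate the device-side scalar/output buffers once.
    // The returned context owns d_data, d_col_idx, d_row_ptr, d_scalars and d_out,
    // and releases them when it is dropped.
    let context = sparse_matrix_witness_init_pallas(&csr);

    witnesses
        .iter()
        .map(|(W, U)| {
            // The full witness vector is [W, u, U]; its length must equal num_cols,
            // which sparse_matrix_witness_with_pallas checks via nW + nU + 1 == num_cols.
            let witness = CudaWitness::new(W, &pallas::Scalar::ONE, U);
            let mut out = vec![pallas::Scalar::ZERO; context.num_rows];
            sparse_matrix_witness_with_pallas(&context, &witness, &mut out, nthreads);
            out
        })
        .collect()
}

The point of the split API is that the init call pays the matrix upload and buffer allocation cost once, so each subsequent sparse_matrix_witness_with_pallas call only transfers the witness; that is the difference examples/spmvm_vesta.rs measures between its "gpu took" and "preallocated gpu took" timings.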