diff --git a/Cargo.toml b/Cargo.toml
index eeeab68..08b7e9d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -28,6 +28,8 @@ force-adx = [ "semolina/force-adx" ]
 cuda-mobile = []
 
 [dependencies]
+abomonation = "0.7.3"
+abomonation_derive = { version = "0.1.0", package = "abomonation_derive_ng" }
 semolina = "~0.1.3"
 sppark = { git = "https://github.com/lurk-lab/sppark", branch = "gpu-spmvm" }
 pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch="dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] }
diff --git a/src/lib.rs b/src/lib.rs
index 00e6878..5c3241a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,9 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 
+pub mod spmvm;
+pub mod utils;
+
 extern crate semolina;
 
 #[cfg(feature = "cuda")]
diff --git a/src/spmvm/mod.rs b/src/spmvm/mod.rs
new file mode 100644
index 0000000..699b7ae
--- /dev/null
+++ b/src/spmvm/mod.rs
@@ -0,0 +1,92 @@
+#![allow(non_snake_case)]
+
+use std::marker::PhantomData;
+
+use pasta_curves::group::ff::PrimeField;
+
+use crate::utils::SparseMatrix;
+
+pub mod pallas;
+pub mod vesta;
+
+#[repr(C)]
+pub struct CudaSparseMatrix<'a, F> {
+    pub data: *const F,
+    pub col_idx: *const usize,
+    pub row_ptr: *const usize,
+
+    pub num_rows: usize,
+    pub num_cols: usize,
+    pub nnz: usize,
+
+    _p: PhantomData<&'a F>,
+}
+
+impl<'a, F> CudaSparseMatrix<'a, F> {
+    pub fn new(
+        data: &[F],
+        col_idx: &[usize],
+        row_ptr: &[usize],
+        num_rows: usize,
+        num_cols: usize,
+    ) -> Self {
+        assert_eq!(
+            data.len(),
+            col_idx.len(),
+            "data and col_idx length mismatch"
+        );
+        assert_eq!(
+            row_ptr.len(),
+            num_rows + 1,
+            "row_ptr length and num_rows mismatch"
+        );
+
+        let nnz = data.len();
+        CudaSparseMatrix {
+            data: data.as_ptr(),
+            col_idx: col_idx.as_ptr(),
+            row_ptr: row_ptr.as_ptr(),
+            num_rows,
+            num_cols,
+            nnz,
+            _p: PhantomData,
+        }
+    }
+}
+
+impl<'a, F: PrimeField> From<&'a SparseMatrix<F>> for CudaSparseMatrix<'a, F> {
+    fn from(value: &'a SparseMatrix<F>) -> Self {
+        CudaSparseMatrix::new(
+            &value.data,
+            &value.indices,
+            &value.indptr,
+            value.indptr.len() - 1,
+            value.cols,
+        )
+    }
+}
+
+#[repr(C)]
+pub struct CudaWitness<'a, F> {
+    pub W: *const F,
+    pub u: *const F,
+    pub U: *const F,
+    pub nW: usize,
+    pub nU: usize,
+    _p: PhantomData<&'a F>,
+}
+
+impl<'a, F> CudaWitness<'a, F> {
+    pub fn new(W: &[F], u: &F, U: &[F]) -> Self {
+        let nW = W.len();
+        let nU = U.len();
+        CudaWitness {
+            W: W.as_ptr(),
+            u: u as *const _,
+            U: U.as_ptr(),
+            nW,
+            nU,
+            _p: PhantomData,
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/spmvm/pallas.rs b/src/spmvm/pallas.rs
new file mode 100644
index 0000000..899fd97
--- /dev/null
+++ b/src/spmvm/pallas.rs
@@ -0,0 +1,217 @@
+#![allow(non_snake_case)]
+
+use std::ffi::c_void;
+
+use pasta_curves::{pallas, group::ff::Field};
+
+use super::{CudaSparseMatrix, CudaWitness};
+
+
+#[repr(C)]
+#[derive(Debug, Clone)]
+pub struct SpMVMContextPallas {
+    pub d_data: *const c_void,
+    pub d_col_idx: *const c_void,
+    pub d_row_ptr: *const c_void,
+
+    pub num_rows: usize,
+    pub num_cols: usize,
+    pub nnz: usize,
+
+    pub d_scalars: *const c_void,
+    pub d_out: *const c_void,
+}
+
+unsafe impl Send for SpMVMContextPallas {}
+unsafe impl Sync for SpMVMContextPallas {}
+
+impl Default for SpMVMContextPallas {
+    fn default() -> Self {
+        Self {
+            d_data: core::ptr::null(),
+            d_col_idx: core::ptr::null(),
+            d_row_ptr: core::ptr::null(),
+            num_rows: 0,
+            num_cols: 0,
+            nnz: 0,
+            d_scalars: core::ptr::null(),
+            d_out: core::ptr::null(),
+        }
+    }
+}
+
+// TODO: check for device-side memory leaks
+impl Drop for SpMVMContextPallas {
+    fn drop(&mut self) {
+        extern "C" {
+            fn drop_spmvm_context_pallas(by_ref: &SpMVMContextPallas);
+        }
+        unsafe {
+            drop_spmvm_context_pallas(std::mem::transmute::<&_, &_>(self))
+        };
+
+        self.d_data = core::ptr::null();
+        self.d_col_idx = core::ptr::null();
+        self.d_row_ptr = core::ptr::null();
+
+        self.num_rows = 0;
+        self.num_cols = 0;
+        self.nnz = 0;
+
+        self.d_scalars = core::ptr::null();
+        self.d_out = core::ptr::null();
+    }
+}
+
+pub fn sparse_matrix_mul_pallas(
+    csr: &CudaSparseMatrix<pallas::Scalar>,
+    scalars: &[pallas::Scalar],
+    nthreads: usize,
+) -> Vec<pallas::Scalar> {
+    extern "C" {
+        fn cuda_sparse_matrix_mul_pallas(
+            csr: *const CudaSparseMatrix<pallas::Scalar>,
+            scalars: *const pallas::Scalar,
+            out: *mut pallas::Scalar,
+            nthreads: usize,
+        ) -> sppark::Error;
+    }
+
+    let mut out = vec![pallas::Scalar::ZERO; csr.num_rows];
+    let err = unsafe {
+        cuda_sparse_matrix_mul_pallas(
+            csr as *const _,
+            scalars.as_ptr(),
+            out.as_mut_ptr(),
+            nthreads,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+
+    out
+}
+
+pub fn sparse_matrix_witness_init_pallas(
+    csr: &CudaSparseMatrix<pallas::Scalar>,
+) -> SpMVMContextPallas {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_init_pallas(
+            csr: *const CudaSparseMatrix<pallas::Scalar>,
+            context: *mut SpMVMContextPallas,
+        ) -> sppark::Error;
+    }
+
+    let mut context = SpMVMContextPallas::default();
+    let err = unsafe {
+        cuda_sparse_matrix_witness_init_pallas(
+            csr as *const _,
+            &mut context as *mut _,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+
+    context
+}
+
+pub fn sparse_matrix_witness_pallas(
+    csr: &CudaSparseMatrix<pallas::Scalar>,
+    witness: &CudaWitness<pallas::Scalar>,
+    buffer: &mut [pallas::Scalar],
+    nthreads: usize,
+) {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_pallas(
+            csr: *const CudaSparseMatrix<pallas::Scalar>,
+            witness: *const CudaWitness<pallas::Scalar>,
+            out: *mut pallas::Scalar,
+            nthreads: usize,
+        ) -> sppark::Error;
+    }
+
+    assert_eq!(
+        witness.nW + witness.nU + 1,
+        csr.num_cols,
+        "invalid witness size"
+    );
+
+    let err = unsafe {
+        cuda_sparse_matrix_witness_pallas(
+            csr as *const _,
+            witness as *const _,
+            buffer.as_mut_ptr(),
+            nthreads,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+}
+
+pub fn sparse_matrix_witness_with_pallas(
+    context: &SpMVMContextPallas,
+    witness: &CudaWitness<pallas::Scalar>,
+    buffer: &mut [pallas::Scalar],
+    nthreads: usize,
+) {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_with_pallas(
+            context: *const SpMVMContextPallas,
+            witness: *const CudaWitness<pallas::Scalar>,
+            out: *mut pallas::Scalar,
+            nthreads: usize,
+        ) -> sppark::Error;
+    }
+
+    assert_eq!(
+        witness.nW + witness.nU + 1,
+        context.num_cols,
+        "invalid witness size"
+    );
+
+    let err = unsafe {
+        cuda_sparse_matrix_witness_with_pallas(
+            context as *const _,
+            witness as *const _,
+            buffer.as_mut_ptr(),
+            nthreads,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+}
+
+pub fn sparse_matrix_witness_pallas_cpu(
+    csr: &CudaSparseMatrix<pallas::Scalar>,
+    witness: &CudaWitness<pallas::Scalar>,
+    buffer: &mut [pallas::Scalar],
+) {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_pallas_cpu(
+            csr: *const CudaSparseMatrix<pallas::Scalar>,
+            witness: *const CudaWitness<pallas::Scalar>,
+            out: *mut pallas::Scalar,
+        ) -> sppark::Error;
+    }
+
+    assert_eq!(
+        witness.nW + witness.nU + 1,
+        csr.num_cols,
+        "invalid witness size"
+    );
+
+    let err = unsafe {
+        cuda_sparse_matrix_witness_pallas_cpu(
+            csr as *const _,
+            witness as *const _,
+            buffer.as_mut_ptr(),
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+}
\ No newline at end of file
diff --git a/src/spmvm/vesta.rs b/src/spmvm/vesta.rs
new file mode 100644
index 0000000..be47db9
--- /dev/null
+++ b/src/spmvm/vesta.rs
@@ -0,0 +1,216 @@
+#![allow(non_snake_case)]
+
+use std::ffi::c_void;
+
+use pasta_curves::{vesta, group::ff::Field};
+
+use super::{CudaSparseMatrix, CudaWitness};
+
+#[repr(C)]
+#[derive(Debug, Clone)]
+pub struct SpMVMContextVesta {
+    pub d_data: *const c_void,
+    pub d_col_idx: *const c_void,
+    pub d_row_ptr: *const c_void,
+
+    pub num_rows: usize,
+    pub num_cols: usize,
+    pub nnz: usize,
+
+    pub d_scalars: *const c_void,
+    pub d_out: *const c_void,
+}
+
+unsafe impl Send for SpMVMContextVesta {}
+unsafe impl Sync for SpMVMContextVesta {}
+
+impl Default for SpMVMContextVesta {
+    fn default() -> Self {
+        Self {
+            d_data: core::ptr::null(),
+            d_col_idx: core::ptr::null(),
+            d_row_ptr: core::ptr::null(),
+            num_rows: 0,
+            num_cols: 0,
+            nnz: 0,
+            d_scalars: core::ptr::null(),
+            d_out: core::ptr::null(),
+        }
+    }
+}
+
+// TODO: check for device-side memory leaks
+impl Drop for SpMVMContextVesta {
+    fn drop(&mut self) {
+        extern "C" {
+            fn drop_spmvm_context_vesta(by_ref: &SpMVMContextVesta);
+        }
+        unsafe {
+            drop_spmvm_context_vesta(std::mem::transmute::<&_, &_>(self))
+        };
+
+        self.d_data = core::ptr::null();
+        self.d_col_idx = core::ptr::null();
+        self.d_row_ptr = core::ptr::null();
+
+        self.num_rows = 0;
+        self.num_cols = 0;
+        self.nnz = 0;
+
+        self.d_scalars = core::ptr::null();
+        self.d_out = core::ptr::null();
+    }
+}
+
+pub fn sparse_matrix_mul_vesta(
+    csr: &CudaSparseMatrix<vesta::Scalar>,
+    scalars: &[vesta::Scalar],
+    nthreads: usize,
+) -> Vec<vesta::Scalar> {
+    extern "C" {
+        fn cuda_sparse_matrix_mul_vesta(
+            csr: *const CudaSparseMatrix<vesta::Scalar>,
+            scalars: *const vesta::Scalar,
+            out: *mut vesta::Scalar,
+            nthreads: usize,
+        ) -> sppark::Error;
+    }
+
+    let mut out = vec![vesta::Scalar::ZERO; csr.num_rows];
+    let err = unsafe {
+        cuda_sparse_matrix_mul_vesta(
+            csr as *const _,
+            scalars.as_ptr(),
+            out.as_mut_ptr(),
+            nthreads,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+
+    out
+}
+
+pub fn sparse_matrix_witness_init_vesta(
+    csr: &CudaSparseMatrix<vesta::Scalar>,
+) -> SpMVMContextVesta {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_init_vesta(
+            csr: *const CudaSparseMatrix<vesta::Scalar>,
+            context: *mut SpMVMContextVesta,
+        ) -> sppark::Error;
+    }
+
+    let mut context = SpMVMContextVesta::default();
+    let err = unsafe {
+        cuda_sparse_matrix_witness_init_vesta(
+            csr as *const _,
+            &mut context as *mut _,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+
+    context
+}
+
+pub fn sparse_matrix_witness_vesta(
+    csr: &CudaSparseMatrix<vesta::Scalar>,
+    witness: &CudaWitness<vesta::Scalar>,
+    buffer: &mut [vesta::Scalar],
+    nthreads: usize,
+) {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_vesta(
+            csr: *const CudaSparseMatrix<vesta::Scalar>,
+            witness: *const CudaWitness<vesta::Scalar>,
+            out: *mut vesta::Scalar,
+            nthreads: usize,
+        ) -> sppark::Error;
+    }
+
+    assert_eq!(
+        witness.nW + witness.nU + 1,
+        csr.num_cols,
+        "invalid witness size"
+    );
+
+    let err = unsafe {
+        cuda_sparse_matrix_witness_vesta(
+            csr as *const _,
+            witness as *const _,
+            buffer.as_mut_ptr(),
+            nthreads,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+}
+
+pub fn sparse_matrix_witness_with_vesta(
+    context: &SpMVMContextVesta,
+    witness: &CudaWitness<vesta::Scalar>,
+    buffer: &mut [vesta::Scalar],
+    nthreads: usize,
+) {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_with_vesta(
+            context: *const SpMVMContextVesta,
+            witness: *const CudaWitness<vesta::Scalar>,
+            out: *mut vesta::Scalar,
+            nthreads: usize,
+        ) -> sppark::Error;
+    }
+
+    assert_eq!(
+        witness.nW + witness.nU + 1,
+        context.num_cols,
+        "invalid witness size"
+    );
+
+    let err = unsafe {
+        cuda_sparse_matrix_witness_with_vesta(
+            context as *const _,
+            witness as *const _,
+            buffer.as_mut_ptr(),
+            nthreads,
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+}
+
+pub fn sparse_matrix_witness_vesta_cpu(
+    csr: &CudaSparseMatrix<vesta::Scalar>,
+    witness: &CudaWitness<vesta::Scalar>,
+    buffer: &mut [vesta::Scalar],
+) {
+    extern "C" {
+        fn cuda_sparse_matrix_witness_vesta_cpu(
+            csr: *const CudaSparseMatrix<vesta::Scalar>,
+            witness: *const CudaWitness<vesta::Scalar>,
+            out: *mut vesta::Scalar,
+        ) -> sppark::Error;
+    }
+
+    assert_eq!(
+        witness.nW + witness.nU + 1,
+        csr.num_cols,
+        "invalid witness size"
+    );
+
+    let err = unsafe {
+        cuda_sparse_matrix_witness_vesta_cpu(
+            csr as *const _,
+            witness as *const _,
+            buffer.as_mut_ptr(),
+        )
+    };
+    if err.code != 0 {
+        panic!("{}", String::from(err));
+    }
+}
\ No newline at end of file
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..05abd45
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,262 @@
+//! # Sparse Matrices
+//!
+//! This module defines a custom implementation of CSR/CSC sparse matrices.
+//! Specifically, we implement sparse matrix / dense vector multiplication
+//! to compute the `A z`, `B z`, and `C z` in Nova.
+
+use std::convert::TryInto;
+
+use abomonation_derive::Abomonation;
+use abomonation::Abomonation;
+use pasta_curves::{group::ff::PrimeField, arithmetic::CurveAffine};
+
+/// CSR format sparse matrix; we follow the names used by scipy
+/// (see the `scipy.sparse.csr_matrix` documentation for details).
+#[derive(Debug, PartialEq, Eq, Abomonation)]
+#[abomonation_bounds(where <F as PrimeField>::Repr: Abomonation)]
+pub struct SparseMatrix<F: PrimeField> {
+    /// all non-zero values in the matrix
+    #[abomonate_with(Vec<F::Repr>)]
+    pub data: Vec<F>,
+    /// column indices
+    pub indices: Vec<usize>,
+    /// row information
+    pub indptr: Vec<usize>,
+    /// number of columns
+    pub cols: usize,
+}
+
+/// [SparseMatrix]s are often large, and this helps with cloning bottlenecks
+impl<F: PrimeField> Clone for SparseMatrix<F> {
+    fn clone(&self) -> Self {
+        Self {
+            data: self.data.iter().cloned().collect(),
+            indices: self.indices.iter().cloned().collect(),
+            indptr: self.indptr.iter().cloned().collect(),
+            cols: self.cols,
+        }
+    }
+}
+
+impl<F: PrimeField> SparseMatrix<F> {
+    /// 0x0 empty matrix
+    pub fn empty() -> Self {
+        SparseMatrix {
+            data: vec![],
+            indices: vec![],
+            indptr: vec![0],
+            cols: 0,
+        }
+    }
+
+    /// Construct from the COO representation; a `Vec` of `(row, col, value)` triples.
+    /// We assume that the rows are sorted during construction.
+    pub fn new(matrix: &[(usize, usize, F)], rows: usize, cols: usize) -> Self {
+        let mut new_matrix = vec![vec![]; rows];
+        for (row, col, val) in matrix {
+            new_matrix[*row].push((*col, *val));
+        }
+
+        for row in new_matrix.iter() {
+            assert!(row.windows(2).all(|w| w[0].0 < w[1].0));
+        }
+
+        let mut indptr = vec![0; rows + 1];
+        for (i, col) in new_matrix.iter().enumerate() {
+            indptr[i + 1] = indptr[i] + col.len();
+        }
+
+        let mut indices = vec![];
+        let mut data = vec![];
+        for col in new_matrix {
+            let (idx, val): (Vec<_>, Vec<_>) = col.into_iter().unzip();
+            indices.extend(idx);
+            data.extend(val);
+        }
+
+        SparseMatrix {
+            data,
+            indices,
+            indptr,
+            cols,
+        }
+    }
+
+    /// Retrieves the data for row slice [i..j] from `ptrs`.
+    /// We assume that `ptrs` is indexed from `indptr` and do not check if the
+    /// returned slice is actually a valid row.
+    pub fn get_row_unchecked(&self, ptrs: &[usize; 2]) -> impl Iterator<Item = (&F, &usize)> {
+        self.data[ptrs[0]..ptrs[1]]
+            .iter()
+            .zip(&self.indices[ptrs[0]..ptrs[1]])
+    }
+
+    /// Multiply by a dense vector.
+    pub fn multiply_vec(&self, vector: &[F]) -> Vec<F> {
+        assert_eq!(self.cols, vector.len(), "invalid shape");
+
+        self.multiply_vec_unchecked(vector)
+    }
+
+    /// Multiply by a dense vector.
+    /// This does not check that the shape of the matrix/vector are compatible.
+    pub fn multiply_vec_unchecked(&self, vector: &[F]) -> Vec<F> {
+        self
+            .indptr
+            .windows(2)
+            .map(|ptrs| {
+                self
+                    .get_row_unchecked(ptrs.try_into().unwrap())
+                    .map(|(val, col_idx)| *val * vector[*col_idx])
+                    .sum()
+            })
+            .collect()
+    }
+
+    /// number of non-zero entries
+    pub fn len(&self) -> usize {
+        *self.indptr.last().unwrap()
+    }
+
+    /// empty matrix
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// returns a custom iterator
+    pub fn iter(&self) -> Iter<'_, F> {
+        let mut row = 0;
+        while self.indptr[row + 1] == 0 {
+            row += 1;
+        }
+        Iter {
+            matrix: self,
+            row,
+            i: 0,
+            nnz: *self.indptr.last().unwrap(),
+        }
+    }
+}
+
+/// Iterator for sparse matrix
+pub struct Iter<'a, F: PrimeField> {
+    matrix: &'a SparseMatrix<F>,
+    row: usize,
+    i: usize,
+    nnz: usize,
+}
+
+impl<'a, F: PrimeField> Iterator for Iter<'a, F> {
+    type Item = (usize, usize, F);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // are we at the end?
+        if self.i == self.nnz {
+            return None;
+        }
+
+        // compute current item
+        let curr_item = (
+            self.row,
+            self.matrix.indices[self.i],
+            self.matrix.data[self.i],
+        );
+
+        // advance the iterator
+        self.i += 1;
+        // edge case at the end
+        if self.i == self.nnz {
+            return Some(curr_item);
+        }
+        // if `i` has moved to the next row
+        while self.i >= self.matrix.indptr[self.row + 1] {
+            self.row += 1;
+        }
+
+        Some(curr_item)
+    }
+}
+
+/// A type that holds commitment generators
+#[derive(Clone, Debug, PartialEq, Eq, Abomonation)]
+#[abomonation_omit_bounds]
+pub struct CommitmentKey<C: CurveAffine>
+{
+    #[abomonate_with(Vec<[u64; 8]>)] // this is a hack; we just assume the size of the element.
+    pub ck: Vec<C>,
+}
+
+// pub fn gen_points(npoints: usize) -> Vec<pallas::Affine> {
+//     let mut ret: Vec<pallas::Affine> = Vec::with_capacity(npoints);
+//     unsafe { ret.set_len(npoints) };
+
+//     let mut rnd: Vec<u8> = Vec::with_capacity(32 * npoints);
+//     unsafe { rnd.set_len(32 * npoints) };
+//     ChaCha20Rng::from_entropy().fill_bytes(&mut rnd);
+
+//     let n_workers = rayon::current_num_threads();
+//     let work = AtomicUsize::new(0);
+//     rayon::scope(|s| {
+//         for _ in 0..n_workers {
+//             s.spawn(|_| {
+//                 let hash = pallas::Point::hash_to_curve("foobar");
+
+//                 let mut stride = 1024;
+//                 let mut tmp: Vec<pallas::Point> = Vec::with_capacity(stride);
+//                 unsafe { tmp.set_len(stride) };
+
+//                 loop {
+//                     let work = work.fetch_add(stride, Ordering::Relaxed);
+//                     if work >= npoints {
+//                         break;
+//                     }
+//                     if work + stride > npoints {
+//                         stride = npoints - work;
+//                         unsafe { tmp.set_len(stride) };
+//                     }
+//                     for i in 0..stride {
+//                         let off = (work + i) * 32;
+//                         tmp[i] = hash(&rnd[off..off + 32]);
+//                     }
+//                     #[allow(mutable_transmutes)]
+//                     pallas::Point::batch_normalize(&tmp, unsafe {
+//                         transmute::<&[pallas::Affine], &mut [pallas::Affine]>(
+//                             &ret[work..work + stride],
+//                         )
+//                     });
+//                 }
+//             })
+//         }
+//     });
+
+//     ret
+// }
+
+// fn as_mut<T>(x: &T) -> &mut T {
+//     unsafe { &mut *UnsafeCell::raw_get(x as *const _ as *const _) }
+// }
+
+// pub fn gen_scalars(npoints: usize) -> Vec<pallas::Scalar> {
+//     let mut ret: Vec<pallas::Scalar> = Vec::with_capacity(npoints);
+//     unsafe { ret.set_len(npoints) };
+
+//     let n_workers = rayon::current_num_threads();
+//     let work = AtomicUsize::new(0);
+
+//     rayon::scope(|s| {
+//         for _ in 0..n_workers {
+//             s.spawn(|_| {
+//                 let mut rng = ChaCha20Rng::from_entropy();
+//                 loop {
+//                     let work = work.fetch_add(1, Ordering::Relaxed);
+//                     if work >= npoints {
+//                         break;
+//                     }
+//                     *as_mut(&ret[work]) = pallas::Scalar::random(&mut rng);
+//                 }
+//             })
+//         }
+//     });
+
+//     ret
+// }
\ No newline at end of file
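---

For reviewers, a minimal host-side usage sketch of the new spmvm API (not part of the patch). It assumes the crate is consumed as `pasta_msm` and built with the `cuda` feature; the helper name `spmvm_pallas_example`, the thread count of 128, and sizing the output buffer to `num_rows` are illustrative choices, not fixed by this diff.

use pasta_curves::{group::ff::Field, pallas};
use pasta_msm::spmvm::{pallas::sparse_matrix_witness_pallas, CudaSparseMatrix, CudaWitness};
use pasta_msm::utils::SparseMatrix;

// Sketch: compute A·z on the GPU for a witness split into (W, u, U),
// using the FFI wrappers introduced in src/spmvm above.
fn spmvm_pallas_example(
    a: &SparseMatrix<pallas::Scalar>, // CSR matrix built via SparseMatrix::new
    w: &[pallas::Scalar],             // witness part W
    x: &[pallas::Scalar],             // instance part U
) -> Vec<pallas::Scalar> {
    // Borrow the CSR matrix as the FFI-friendly view; no data is copied here.
    let csr = CudaSparseMatrix::from(a);

    // CudaWitness carries raw pointers to W, u, and U; the kernels assert
    // that nW + nU + 1 == num_cols. Using u = 1 here purely for illustration.
    let u = pallas::Scalar::ONE;
    let witness = CudaWitness::new(w, &u, x);

    // Output buffer sized to the number of matrix rows, mirroring what
    // sparse_matrix_mul_pallas allocates (an assumption for this sketch).
    let mut out = vec![pallas::Scalar::ZERO; csr.num_rows];
    sparse_matrix_witness_pallas(&csr, &witness, &mut out, 128);
    out
}

For repeated multiplications against the same matrix, `sparse_matrix_witness_init_pallas` can be called once to obtain an `SpMVMContextPallas` holding the device-side copies, with `sparse_matrix_witness_with_pallas` reusing that context on each call.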