Skip to content
This repository has been archived by the owner on Feb 19, 2024. It is now read-only.

Commit

Permalink
spmvm
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Dec 13, 2023
1 parent c203abd commit 2955efe
Show file tree
Hide file tree
Showing 6 changed files with 792 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ force-adx = [ "semolina/force-adx" ]
cuda-mobile = []

[dependencies]
abomonation = "0.7.3"
abomonation_derive = { version = "0.1.0", package = "abomonation_derive_ng" }
semolina = "~0.1.3"
sppark = { git = "https://github.com/lurk-lab/sppark", branch = "gpu-spmvm" }
pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch="dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] }
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

pub mod spmvm;
pub mod utils;

extern crate semolina;

#[cfg(feature = "cuda")]
Expand Down
92 changes: 92 additions & 0 deletions src/spmvm/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#![allow(non_snake_case)]

use std::marker::PhantomData;

use pasta_curves::group::ff::PrimeField;

use crate::utils::SparseMatrix;

pub mod pallas;
pub mod vesta;

/// Borrowed, FFI-compatible view of a sparse matrix stored in compressed
/// sparse row (CSR) form.
///
/// The struct is `#[repr(C)]` and is passed by pointer to the `extern "C"`
/// backend functions, so the order and types of its fields must not change.
/// It owns nothing: the three pointers refer to slices borrowed for `'a`, and
/// the `PhantomData` marker lets the borrow checker keep those slices alive
/// for as long as the view exists.
#[repr(C)]
pub struct CudaSparseMatrix<'a, F> {
    /// Pointer to the first of `nnz` stored values.
    pub data: *const F,
    /// Pointer to the first of `nnz` column indices (one per stored value).
    pub col_idx: *const usize,
    /// Pointer to the first of `num_rows + 1` row offsets.
    pub row_ptr: *const usize,

    pub num_rows: usize,
    pub num_cols: usize,
    /// Number of stored values, i.e. the length of `data` and `col_idx`.
    pub nnz: usize,

    _p: PhantomData<&'a F>,
}

impl<'a, F> CudaSparseMatrix<'a, F> {
    /// Creates a view over the given CSR arrays.
    ///
    /// All three slices are borrowed for `'a`, the lifetime carried by the
    /// returned view, so the view cannot outlive the memory its raw pointers
    /// refer to.
    ///
    /// # Panics
    ///
    /// Panics if `data` and `col_idx` differ in length, if `row_ptr` does not
    /// hold exactly `num_rows + 1` entries, or if `num_rows + 1` overflows
    /// `usize`.
    pub fn new(
        data: &'a [F],
        col_idx: &'a [usize],
        row_ptr: &'a [usize],
        num_rows: usize,
        num_cols: usize,
    ) -> Self {
        assert_eq!(
            data.len(),
            col_idx.len(),
            "data and col_idx length mismatch"
        );
        // `num_rows + 1` would silently wrap to 0 in release builds and let an
        // empty `row_ptr` pass the check below, so compute it with overflow
        // detection.
        let expected_row_ptr_len = num_rows
            .checked_add(1)
            .expect("num_rows + 1 overflows usize");
        assert_eq!(
            row_ptr.len(),
            expected_row_ptr_len,
            "row_ptr length and num_rows mismatch"
        );

        let nnz = data.len();
        CudaSparseMatrix {
            data: data.as_ptr(),
            col_idx: col_idx.as_ptr(),
            row_ptr: row_ptr.as_ptr(),
            num_rows,
            num_cols,
            nnz,
            _p: PhantomData,
        }
    }
}

impl<'a, F: PrimeField> From<&'a SparseMatrix<F>> for CudaSparseMatrix<'a, F> {
    /// Borrows the arrays of `value` as a [`CudaSparseMatrix`] view.
    ///
    /// The row count is derived from `indptr`, which holds one more entry
    /// than there are rows.
    ///
    /// # Panics
    ///
    /// Panics if `value.indptr` is empty, or if `value.data` and
    /// `value.indices` differ in length.
    fn from(value: &'a SparseMatrix<F>) -> Self {
        // A plain `len() - 1` wraps to `usize::MAX` in release builds when
        // `indptr` is empty; fail loudly instead.
        let num_rows = value
            .indptr
            .len()
            .checked_sub(1)
            .expect("SparseMatrix::indptr must contain at least one offset");
        CudaSparseMatrix::new(
            &value.data,
            &value.indices,
            &value.indptr,
            num_rows,
            value.cols,
        )
    }
}

/// Borrowed, FFI-compatible view of a witness made of two slices (`W`, `U`)
/// and a single element (`u`).
///
/// The struct is `#[repr(C)]` and is passed by pointer to the `extern "C"`
/// backend functions, so the order and types of its fields must not change.
/// It owns nothing: the pointers refer to data borrowed for `'a`, and the
/// `PhantomData` marker lets the borrow checker keep that data alive for as
/// long as the view exists.
#[repr(C)]
pub struct CudaWitness<'a, F> {
    /// Pointer to the first of `nW` elements.
    pub W: *const F,
    /// Pointer to a single element.
    pub u: *const F,
    /// Pointer to the first of `nU` elements.
    pub U: *const F,
    /// Length of the slice behind `W`.
    pub nW: usize,
    /// Length of the slice behind `U`.
    pub nU: usize,
    _p: PhantomData<&'a F>,
}

impl<'a, F> CudaWitness<'a, F> {
    /// Creates a view over `W`, `u` and `U`.
    ///
    /// All three inputs are borrowed for `'a`, the lifetime carried by the
    /// returned view, so the view cannot outlive the memory its raw pointers
    /// refer to.
    pub fn new(W: &'a [F], u: &'a F, U: &'a [F]) -> Self {
        CudaWitness {
            W: W.as_ptr(),
            u: u as *const F,
            U: U.as_ptr(),
            nW: W.len(),
            nU: U.len(),
            _p: PhantomData,
        }
    }
}
217 changes: 217 additions & 0 deletions src/spmvm/pallas.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
#![allow(non_snake_case)]

use std::ffi::c_void;

use pasta_curves::{pallas, group::ff::Field};

use super::{CudaSparseMatrix, CudaWitness};


/// State shared with the CUDA backend between calls to
/// `sparse_matrix_witness_init_pallas` and `sparse_matrix_witness_with_pallas`.
///
/// The struct is `#[repr(C)]` and is passed by pointer to `extern "C"`
/// functions, so the order and types of its fields must not change. A
/// default-constructed value has null pointers and zero sizes; the init
/// function hands it to the backend by mutable pointer to be filled in.
//
// NOTE(review): this type derives `Clone` and also implements `Drop`, which
// calls `drop_spmvm_context_pallas`. If that backend function releases the
// buffers behind the `d_*` pointers (its implementation is not visible here),
// every clone would release the same buffers again — confirm, and consider
// removing `Clone`.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct SpMVMContextPallas {
// Opaque pointers managed by the backend; presumably device allocations
// holding the CSR arrays (the `d_` prefix suggests device memory — confirm).
pub d_data: *const c_void,
pub d_col_idx: *const c_void,
pub d_row_ptr: *const c_void,

// Matrix dimensions; `num_cols` is compared against the witness size before
// each `sparse_matrix_witness_with_pallas` call.
pub num_rows: usize,
pub num_cols: usize,
pub nnz: usize,

// Opaque pointers managed by the backend; presumably scratch buffers for the
// input scalars and the result — confirm against the backend sources.
pub d_scalars: *const c_void,
pub d_out: *const c_void,
}

// SAFETY — NOTE(review): the struct holds only raw pointers and plain integers,
// and nothing in this file dereferences those pointers on the host. Whether it
// is sound to move or share a context across threads depends on the backend
// functions tolerating calls from any thread, which is not visible here —
// confirm against the backend sources.
unsafe impl Send for SpMVMContextPallas {}
unsafe impl Sync for SpMVMContextPallas {}

/// An empty context: every pointer is null and every size is zero.
impl Default for SpMVMContextPallas {
    fn default() -> Self {
        // One null value reused for every pointer field; its pointee type is
        // inferred from the fields it is assigned to.
        let unset = core::ptr::null();
        Self {
            d_data: unset,
            d_col_idx: unset,
            d_row_ptr: unset,
            num_rows: 0,
            num_cols: 0,
            nnz: 0,
            d_scalars: unset,
            d_out: unset,
        }
    }
}

// TODO: check for device-side memory leaks
/// Hands the context back to the backend, then clears every field.
impl Drop for SpMVMContextPallas {
    fn drop(&mut self) {
        extern "C" {
            fn drop_spmvm_context_pallas(by_ref: &SpMVMContextPallas);
        }
        // A plain reborrow turns `&mut Self` into `&Self`; no `transmute` is
        // needed, and the compiler keeps checking the types.
        //
        // SAFETY: `self` is a live, exclusively borrowed context for the whole
        // call. NOTE(review): this assumes the backend accepts contexts whose
        // pointers are still null (e.g. a default value that was never
        // initialised) — confirm against the backend sources.
        unsafe { drop_spmvm_context_pallas(&*self) };

        // Leave no stale pointers or sizes behind in the dropped value.
        self.d_data = core::ptr::null();
        self.d_col_idx = core::ptr::null();
        self.d_row_ptr = core::ptr::null();

        self.num_rows = 0;
        self.num_cols = 0;
        self.nnz = 0;

        self.d_scalars = core::ptr::null();
        self.d_out = core::ptr::null();
    }
}

/// Multiplies the sparse matrix `csr` by the dense vector `scalars` through
/// the CUDA backend and returns a vector of `csr.num_rows` elements.
///
/// `nthreads` is forwarded unchanged to the backend.
///
/// # Panics
///
/// Panics if `scalars` holds fewer than `csr.num_cols` elements, or if the
/// backend reports a non-zero error code.
pub fn sparse_matrix_mul_pallas(
    csr: &CudaSparseMatrix<pallas::Scalar>,
    scalars: &[pallas::Scalar],
    nthreads: usize,
) -> Vec<pallas::Scalar> {
    extern "C" {
        fn cuda_sparse_matrix_mul_pallas(
            csr: *const CudaSparseMatrix<pallas::Scalar>,
            scalars: *const pallas::Scalar,
            out: *mut pallas::Scalar,
            nthreads: usize,
        ) -> sppark::Error;
    }

    // Only a raw pointer crosses the FFI boundary, so the backend cannot see
    // the slice length; reject inputs that do not cover every column.
    assert!(
        scalars.len() >= csr.num_cols,
        "scalars length is smaller than num_cols"
    );

    let mut out = vec![pallas::Scalar::ZERO; csr.num_rows];
    // SAFETY: `csr` is a valid reference, `scalars` covers at least `num_cols`
    // elements (checked above) and `out` holds exactly `num_rows` elements.
    // The pointers stored inside `csr` are trusted to match its recorded
    // sizes, as established by `CudaSparseMatrix::new`.
    let err = unsafe {
        cuda_sparse_matrix_mul_pallas(
            csr as *const _,
            scalars.as_ptr(),
            out.as_mut_ptr(),
            nthreads,
        )
    };
    if err.code != 0 {
        panic!("{}", String::from(err));
    }

    out
}

/// Asks the backend to set up a reusable [`SpMVMContextPallas`] for `csr`.
///
/// A default (empty) context is created on the host and passed to the backend
/// by mutable pointer to be filled in.
///
/// # Panics
///
/// Panics with the backend's message if it reports a non-zero error code.
pub fn sparse_matrix_witness_init_pallas(
    csr: &CudaSparseMatrix<pallas::Scalar>,
) -> SpMVMContextPallas {
    extern "C" {
        fn cuda_sparse_matrix_witness_init_pallas(
            csr: *const CudaSparseMatrix<pallas::Scalar>,
            context: *mut SpMVMContextPallas,
        ) -> sppark::Error;
    }

    let mut ctx = SpMVMContextPallas::default();
    // The references coerce to the raw pointers the foreign signature expects.
    let status = unsafe { cuda_sparse_matrix_witness_init_pallas(csr, &mut ctx) };

    match status.code {
        0 => ctx,
        _ => panic!("{}", String::from(status)),
    }
}

/// Runs the backend's sparse-matrix/witness product for `csr` and `witness`,
/// writing the result into `buffer`.
///
/// `nthreads` is forwarded unchanged to the backend.
///
/// # Panics
///
/// Panics if `witness.nW + witness.nU + 1` differs from `csr.num_cols`, if
/// `buffer` holds fewer than `csr.num_rows` elements, or if the backend
/// reports a non-zero error code.
pub fn sparse_matrix_witness_pallas(
    csr: &CudaSparseMatrix<pallas::Scalar>,
    witness: &CudaWitness<pallas::Scalar>,
    buffer: &mut [pallas::Scalar],
    nthreads: usize,
) {
    extern "C" {
        fn cuda_sparse_matrix_witness_pallas(
            csr: *const CudaSparseMatrix<pallas::Scalar>,
            witness: *const CudaWitness<pallas::Scalar>,
            out: *mut pallas::Scalar,
            nthreads: usize,
        ) -> sppark::Error;
    }

    assert_eq!(
        witness.nW + witness.nU + 1,
        csr.num_cols,
        "invalid witness size"
    );
    // Only a raw pointer crosses the FFI boundary, so the backend cannot see
    // the buffer length; `sparse_matrix_mul_pallas` sizes its output as
    // `num_rows`, and the same minimum is required here.
    assert!(
        buffer.len() >= csr.num_rows,
        "output buffer is smaller than num_rows"
    );

    // SAFETY: `csr` and `witness` are valid references for the duration of the
    // call and `buffer` has room for `num_rows` elements (checked above). The
    // pointers stored inside the two views are trusted to match their recorded
    // sizes, as established by their constructors.
    let err = unsafe {
        cuda_sparse_matrix_witness_pallas(
            csr as *const _,
            witness as *const _,
            buffer.as_mut_ptr(),
            nthreads,
        )
    };
    if err.code != 0 {
        panic!("{}", String::from(err));
    }
}

/// Runs the backend's sparse-matrix/witness product using a previously
/// initialised `context`, writing the result into `buffer`.
///
/// `nthreads` is forwarded unchanged to the backend.
///
/// # Panics
///
/// Panics if `witness.nW + witness.nU + 1` differs from `context.num_cols`,
/// if `buffer` holds fewer than `context.num_rows` elements, or if the
/// backend reports a non-zero error code.
pub fn sparse_matrix_witness_with_pallas(
    context: &SpMVMContextPallas,
    witness: &CudaWitness<pallas::Scalar>,
    buffer: &mut [pallas::Scalar],
    nthreads: usize,
) {
    extern "C" {
        fn cuda_sparse_matrix_witness_with_pallas(
            context: *const SpMVMContextPallas,
            witness: *const CudaWitness<pallas::Scalar>,
            out: *mut pallas::Scalar,
            nthreads: usize,
        ) -> sppark::Error;
    }

    assert_eq!(
        witness.nW + witness.nU + 1,
        context.num_cols,
        "invalid witness size"
    );
    // Only a raw pointer crosses the FFI boundary, so the backend cannot see
    // the buffer length; the product has one entry per matrix row.
    assert!(
        buffer.len() >= context.num_rows,
        "output buffer is smaller than num_rows"
    );

    // SAFETY: `context` and `witness` are valid references for the duration of
    // the call and `buffer` has room for `num_rows` elements (checked above).
    // The pointers stored inside `context` are trusted to be the ones the
    // backend wrote during initialisation.
    let err = unsafe {
        cuda_sparse_matrix_witness_with_pallas(
            context as *const _,
            witness as *const _,
            buffer.as_mut_ptr(),
            nthreads,
        )
    };
    if err.code != 0 {
        panic!("{}", String::from(err));
    }
}

/// Runs the backend's `_cpu` variant of the sparse-matrix/witness product for
/// `csr` and `witness`, writing the result into `buffer`.
///
/// # Panics
///
/// Panics if `witness.nW + witness.nU + 1` differs from `csr.num_cols`, if
/// `buffer` holds fewer than `csr.num_rows` elements, or if the backend
/// reports a non-zero error code.
pub fn sparse_matrix_witness_pallas_cpu(
    csr: &CudaSparseMatrix<pallas::Scalar>,
    witness: &CudaWitness<pallas::Scalar>,
    buffer: &mut [pallas::Scalar],
) {
    extern "C" {
        fn cuda_sparse_matrix_witness_pallas_cpu(
            csr: *const CudaSparseMatrix<pallas::Scalar>,
            witness: *const CudaWitness<pallas::Scalar>,
            out: *mut pallas::Scalar,
        ) -> sppark::Error;
    }

    assert_eq!(
        witness.nW + witness.nU + 1,
        csr.num_cols,
        "invalid witness size"
    );
    // Only a raw pointer crosses the FFI boundary, so the backend cannot see
    // the buffer length; the product has one entry per matrix row.
    assert!(
        buffer.len() >= csr.num_rows,
        "output buffer is smaller than num_rows"
    );

    // SAFETY: `csr` and `witness` are valid references for the duration of the
    // call and `buffer` has room for `num_rows` elements (checked above). The
    // pointers stored inside the two views are trusted to match their recorded
    // sizes, as established by their constructors.
    let err = unsafe {
        cuda_sparse_matrix_witness_pallas_cpu(
            csr as *const _,
            witness as *const _,
            buffer.as_mut_ptr(),
        )
    };
    if err.code != 0 {
        panic!("{}", String::from(err));
    }
}
Loading

0 comments on commit 2955efe

Please sign in to comment.