Skip to content
This repository has been archived by the owner on Feb 19, 2024. It is now read-only.

Commit

Permalink
so close
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Dec 2, 2023
1 parent c2e1370 commit 8a10540
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cuda-mobile = []

[dependencies]
semolina = "~0.1.3"
sppark = { git = "https://github.com/lurk-lab/sppark.git", branch = "preallocated-msm" }
sppark = { git = "https://github.com/lurk-lab/sppark.git", branch = "pushing-the-limit" }
pasta_curves = { git = "https://github.com/lurk-lab/pasta_curves", branch = "dev", version = ">=0.3.1, <=0.5", features = ["repr-c"] }
paste = "1.0.14"

Expand Down
11 changes: 9 additions & 2 deletions cuda/pallas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ typedef bucket_t::affine_t affine_t;
typedef vesta_t scalar_t;

#include <msm/pippenger.cuh>
#include <spmvm/spmvm.cuh>

#ifndef __CUDA_ARCH__

extern "C" void drop_msm_context_pallas(msm_context_t<affine_t::mem_t> &ref) {
extern "C" RustError spmvm_pallas(scalar_t *scalars, size_t nscalars)
{
return double_scalars<scalar_t>(scalars, nscalars);
}

extern "C" void drop_msm_context_pallas(msm_context_t<affine_t::mem_t> &ref)
{
CUDA_OK(cudaFree(ref.d_points));
}

Expand All @@ -35,7 +42,7 @@ extern "C" RustError cuda_pippenger_pallas(point_t *out, const affine_t points[]
}

extern "C" RustError cuda_pippenger_pallas_with(point_t *out, msm_context_t<affine_t::mem_t> *msm_context, size_t npoints,
const scalar_t scalars[])
const scalar_t scalars[])
{
return mult_pippenger_with<bucket_t, point_t, affine_t, scalar_t>(out, msm_context, npoints, scalars);
}
Expand Down
41 changes: 41 additions & 0 deletions examples/spmvm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

use std::time::Instant;

use pasta_curves::group::ff::PrimeField;
use pasta_msm::spmvm::double_pallas;

pub fn generate_scalars<F: PrimeField>(
len: usize,
) -> Vec<F> {
let mut rng = rand::thread_rng();
let scalars = (0..len)
.map(|_| F::random(&mut rng))
.collect::<Vec<_>>();

scalars
}

/// cargo run --release --example spmvm
fn main() {
let npow: usize = std::env::var("NPOW")
.unwrap_or("23".to_string())
.parse()
.unwrap();
let n = 1usize << npow;

let mut scalars = generate_scalars(n);

let start = Instant::now();
let double_scalars = scalars.iter().map(|x| x + x).collect::<Vec<_>>();
println!("cpu took: {:?}", start.elapsed());

let start = Instant::now();
double_pallas(&mut scalars);
println!("gpu took: {:?}", start.elapsed());

assert_eq!(double_scalars, scalars);
println!("success!");
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

pub mod spmvm;

extern crate semolina;

#[cfg(feature = "cuda")]
Expand Down
21 changes: 21 additions & 0 deletions src/spmvm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use pasta_curves::pallas;

pub fn sparse_matrix_mul_pallas(scalars: &mut [pallas::Scalar]) {
extern "C" {
fn spmvm_pallas(
scalars: *mut pallas::Scalar,
nscalars: usize,
) -> sppark::Error;
}

let nscalars = scalars.len();
let err = unsafe {
spmvm_pallas(
scalars.as_mut_ptr(),
nscalars,
)
};
if err.code != 0 {
panic!("{}", String::from(err));
}
}

0 comments on commit 8a10540

Please sign in to comment.