Skip to content

Commit

Permalink
multi-thread solver
Browse files Browse the repository at this point in the history
  • Loading branch information
hczphn committed Jan 7, 2025
1 parent 822fa99 commit c9c8888
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ libec_go_lib.*
.vscode
.code
*.log
*.witness
.DS_Store
2 changes: 2 additions & 0 deletions circuit-std-rs/src/gnark/hints.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use crate::big_int::to_binary_hint;
use crate::gnark::limbs::*;
use crate::gnark::utils::*;
use crate::gnark::emparam::FieldParams;
Expand All @@ -23,6 +24,7 @@ use ark_ff::fields::Field;
use num_traits::One;

pub fn register_hint(hint_registry: &mut HintRegistry<M31>) {
hint_registry.register("myhint.tobinary", to_binary_hint);
hint_registry.register("myhint.mulhint", mul_hint);
hint_registry.register("myhint.simple_rangecheck_hint", simple_rangecheck_hint);
hint_registry.register("myhint.querycounthint", query_count_hint);
Expand Down
85 changes: 84 additions & 1 deletion efc/src/bls_verifier.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::thread;
use std::cell::RefCell;
use std::sync::Arc;
use std::rc::Rc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use ark_bls12_381::g2;
use circuit_std_rs::gnark::hints::register_hint;
use expander_compiler::circuit::ir::hint_normalized::witness_solver;
use expander_compiler::frontend::*;
use expander_config::M31ExtConfigSha2;
use num_bigint::BigInt;
Expand Down Expand Up @@ -124,7 +126,6 @@ fn run_expander_pairing(){
for i in 0..test_time {
assignments.push(assignment.clone());
}

let compile_result = compile_generic(&PairingCircuit::default(),CompileOptions::default()).unwrap();
let start_time = std::time::Instant::now();
let witness = compile_result
Expand All @@ -136,4 +137,86 @@ fn run_expander_pairing(){
run_circuit::<M31Config, M31ExtConfigSha2>(&compile_result, witness);
let end_time = std::time::Instant::now();
println!("Generate witness Time: {:?}", end_time.duration_since(start_time));
}

#[test]
fn run_multi_pairing(){
/*
hm E([2128747184964102066453428909345807587167353354433686779055175069717994597853044053001604474195549116663962354781667+600928199043548865756890420428378235956589666349872943435617471245143322438124492345775032317976373712791854412075*u,2673014212711484998033216133821539885421138070306477264866327549730911573831074801525177859765712567167095903919303+843401639836709482028685764607129261791330643868212867532430090507242037514006427793603581220496836139166547085499*u])
sig E([963823355633972122114533498175662916621992470505354782789337615847591161145194281419366975300935939968232579346290+596907481049847637954275493859228934805964488037826922094320375977359016208358247522168009186501678750789366694831*u,1503040898615551538476187079486863259539849948567091887110583169943865184109068018840042625482669131770515482621711+3444166137003222945962463909857562676481832034105318967013156342862358108020440293426901361538632823324929201906078*u])
aggPubkey E([3103244252149090420124940058491173358275189586453938010595576928631997313493844448363005953641905183987079560513835,1296246409150097609953508557969533080097715407458068120115474713311006715865163545587973784795351244083056720382121])
*/
let assignment = PairingCircuit::<M31> {
pubkey: [string_to_m31_array("3103244252149090420124940058491173358275189586453938010595576928631997313493844448363005953641905183987079560513835", 8),
string_to_m31_array("1296246409150097609953508557969533080097715407458068120115474713311006715865163545587973784795351244083056720382121", 8)],
hm: [
[string_to_m31_array("2128747184964102066453428909345807587167353354433686779055175069717994597853044053001604474195549116663962354781667", 8),
string_to_m31_array("600928199043548865756890420428378235956589666349872943435617471245143322438124492345775032317976373712791854412075", 8)],
[string_to_m31_array("2673014212711484998033216133821539885421138070306477264866327549730911573831074801525177859765712567167095903919303", 8),
string_to_m31_array("843401639836709482028685764607129261791330643868212867532430090507242037514006427793603581220496836139166547085499", 8)]
],
sig: [
[string_to_m31_array("963823355633972122114533498175662916621992470505354782789337615847591161145194281419366975300935939968232579346290", 8),
string_to_m31_array("596907481049847637954275493859228934805964488037826922094320375977359016208358247522168009186501678750789366694831", 8),],
[string_to_m31_array("1503040898615551538476187079486863259539849948567091887110583169943865184109068018840042625482669131770515482621711", 8),
string_to_m31_array("3444166137003222945962463909857562676481832034105318967013156342862358108020440293426901361538632823324929201906078", 8)]
]
};
let test_time = 2048;
let mut assignments = vec![];
let mut hint_registries = vec![];
for i in 0..test_time {
assignments.push(assignment.clone());
}
for i in 0..test_time/16 {
let mut hint_registry = HintRegistry::<M31>::new();
register_hint(&mut hint_registry);
hint_registries.push(hint_registry);
}

let assignment_chunks: Vec<Vec<PairingCircuit<M31>>> =
assignments.chunks(16).map(|x| x.to_vec()).collect();
let mut w_s: witness_solver::WitnessSolver::<M31Config>;
if std::fs::metadata("pairing.witness").is_ok() {
println!("The file exists!");
w_s = witness_solver::WitnessSolver::deserialize_from(std::fs::File::open("pairing.witness").unwrap()).unwrap();
} else {
println!("The file does not exist.");
let compile_result = compile_generic(&PairingCircuit::default(), CompileOptions::default()).unwrap();
compile_result.witness_solver.serialize_into(std::fs::File::create("pairing.witness").unwrap()).unwrap();
w_s = compile_result.witness_solver;
}
let witness_solver = Arc::new(w_s);
let start_time = std::time::Instant::now();
let handles = assignment_chunks
.into_iter()
.zip(hint_registries)
.map(|(assignments, hint_registry)| {
let witness_solver = Arc::clone(&witness_solver);
thread::spawn(move || {
let mut hint_registry1 = HintRegistry::<M31>::new();
register_hint(&mut hint_registry1);
witness_solver.solve_witnesses_with_hints(&assignments, &mut hint_registry1).unwrap();
}
)
})
.collect::<Vec<_>>();
// let handles = assignment_chunks
// .into_iter()
// .map(|assignments| {
// let witness_solver = Arc::clone(&witness_solver);
// let hint_register = Arc::clone(&share_hint_registry);
// thread::spawn(move || witness_solver.solve_witnesses_with_hints(&assignments, &mut ).unwrap())
// })
// .collect::<Vec<_>>();
let mut results = Vec::new();
for handle in handles {
results.push(handle.join().unwrap());
}
let end_time = std::time::Instant::now();
println!("Generate witness Time: {:?}", end_time.duration_since(start_time));
// for result in results {
// let output = compile_result.layered_circuit.run(&result);
// assert_eq!(output, vec![true; 16]);
// }
}
112 changes: 107 additions & 5 deletions efc/src/hashtable.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
use std::sync::{Arc, Mutex};
use std::thread;
use std::cell::RefCell;
use std::sync::Arc;
use std::rc::Rc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use ark_bls12_381::g2;
use circuit_std_rs::gnark::hints::register_hint;
use expander_compiler::circuit::ir::hint_normalized::witness_solver;
use expander_compiler::frontend::*;
use expander_config::M31ExtConfigSha2;
use num_bigint::BigInt;
use sha2::{Digest, Sha256};
use circuit_std_rs::big_int::{to_binary_hint, big_array_add};
use circuit_std_rs::sha2_m31::check_sha256;
use circuit_std_rs::gnark::emulated::field_bls12381::*;
use circuit_std_rs::gnark::emulated::field_bls12381::e2::*;
use circuit_std_rs::gnark::emulated::sw_bls12381::pairing::*;
use circuit_std_rs::gnark::emulated::sw_bls12381::g1::*;
use circuit_std_rs::gnark::emulated::sw_bls12381::g2::*;
use circuit_std_rs::gnark::element::*;
use expander_compiler::frontend::extra::*;
use circuit_std_rs::big_int::*;
use expander_compiler::{circuit::layered::InputType, frontend::*};

use crate::utils::run_circuit;


const SHA256LEN: usize = 32;
const HASHTABLESIZE: usize = 64;
#[derive(Clone, Copy, Debug)]
Expand All @@ -24,8 +39,8 @@ declare_circuit!(HASHTABLECircuit {
seed: [PublicVariable; SHA256LEN],
output: [[Variable;SHA256LEN];HASHTABLESIZE],
});
impl Define<M31Config> for HASHTABLECircuit<Variable> {
fn define(&self, builder: &mut API<M31Config>) {
impl GenericDefine<M31Config> for HASHTABLECircuit<Variable> {
fn define<Builder: RootAPI<M31Config>>(&self, builder: &mut Builder) {
let mut indices = vec![Vec::<Variable>::new(); HASHTABLESIZE];
if HASHTABLESIZE > 256 {
panic!("HASHTABLESIZE > 256")
Expand Down Expand Up @@ -95,7 +110,7 @@ fn test_hashtable(){
for i in 0..test_time {
assignments.push(assignment.clone());
}
let compile_result = compile(&HASHTABLECircuit::default()).unwrap();
let compile_result = compile_generic(&HASHTABLECircuit::default(), CompileOptions::default()).unwrap();
let witness_solver = compile_result.witness_solver.clone();
let start_time = std::time::Instant::now();
for i in 0..test_time {
Expand Down Expand Up @@ -160,7 +175,7 @@ fn run_expander_hashtable(){
assignments.push(assignment.clone());
}

let compile_result = compile(&HASHTABLECircuit::default()).unwrap();
let compile_result = compile_generic(&HASHTABLECircuit::default(), CompileOptions::default()).unwrap();
let mut hint_registry = HintRegistry::<M31>::new();
hint_registry.register("myhint.tobinary", to_binary_hint);
let start_time = std::time::Instant::now();
Expand All @@ -173,4 +188,91 @@ fn run_expander_hashtable(){
run_circuit::<M31Config, M31ExtConfigSha2>(&compile_result, witness);
let end_time = std::time::Instant::now();
println!("Generate witness Time: {:?}", end_time.duration_since(start_time));
}

#[test]
fn run_multi_hashtable(){
let seed = [0 as u8;32];
let round = 0 as u8;
let start_index = [0 as u8;4];
let mut assignment:HASHTABLECircuit<M31> = HASHTABLECircuit::default();
for i in 0..32 {
assignment.seed[i] = M31::from(seed[i] as u32);
}

assignment.shuffle_round = M31::from(round as u32);
for i in 0..4 {
assignment.start_index[i] = M31::from(start_index[i] as u32);
}
let mut inputs = vec![];
let mut cur_index = start_index;
for i in 0..HASHTABLESIZE{
let mut input = vec![];
input.extend_from_slice(&seed);
input.push(round);
input.extend_from_slice(&cur_index);
if cur_index[0] == 255 {
cur_index[0] = 0;
cur_index[1] += 1;
} else {
cur_index[0] += 1;
}
inputs.push(input);
}
for i in 0..HASHTABLESIZE{
let data = inputs[i].to_vec();
let mut hash = Sha256::new();
hash.update(&data);
let output = hash.finalize();
for j in 0..32 {
assignment.output[i][j] = M31::from(output[j] as u32);
}
}
let test_time = 2880;
let mut assignments = vec![];
for i in 0..test_time {
assignments.push(assignment.clone());
}

let assignment_chunks: Vec<Vec<HASHTABLECircuit<M31>>> =
assignments.chunks(16).map(|x| x.to_vec()).collect();
let mut w_s: witness_solver::WitnessSolver::<M31Config>;
if std::fs::metadata("hashtable.witness").is_ok() {
println!("The file exists!");
w_s = witness_solver::WitnessSolver::deserialize_from(std::fs::File::open("hashtable.witness").unwrap()).unwrap();
} else {
println!("The file does not exist.");
let compile_result = compile_generic(&HASHTABLECircuit::default(), CompileOptions::default()).unwrap();
compile_result.witness_solver.serialize_into(std::fs::File::create("hashtable.witness").unwrap()).unwrap();
w_s = compile_result.witness_solver;
}
let witness_solver = Arc::new(w_s);
let start_time = std::time::Instant::now();
let handles = assignment_chunks
.into_iter()
.map(|(assignments)| {
let witness_solver = Arc::clone(&witness_solver);
thread::spawn(move || {
println!("start");
let mut hint_registry1 = HintRegistry::<M31>::new();
register_hint(&mut hint_registry1);
witness_solver.solve_witnesses_with_hints(&assignments, &mut hint_registry1).unwrap();
}
)
})
.collect::<Vec<_>>();
// let handles = assignment_chunks
// .into_iter()
// .map(|assignments| {
// let witness_solver = Arc::clone(&witness_solver);
// let hint_register = Arc::clone(&share_hint_registry);
// thread::spawn(move || witness_solver.solve_witnesses_with_hints(&assignments, &mut ).unwrap())
// })
// .collect::<Vec<_>>();
let mut results = Vec::new();
for handle in handles {
results.push(handle.join().unwrap());
}
let end_time = std::time::Instant::now();
println!("Generate witness Time: {:?}", end_time.duration_since(start_time));
}
2 changes: 1 addition & 1 deletion expander_compiler/src/hints/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl<F: Field> HintRegistry<F> {
) {
let id = hint_key_to_id(key);
if self.hints.contains_key(&id) {
panic!("Hint with id {} already exists", id);
panic!("Hint with id {} already exists, key{}", id, key);
}
self.hints.insert(id, Box::new(hint));
}
Expand Down

0 comments on commit c9c8888

Please sign in to comment.