diff --git a/.gitignore b/.gitignore index 3f7f23880..b432b7fc9 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target log.txt logs/ table_cache_dev_* +.DS_Store diff --git a/Cargo.lock b/Cargo.lock index 24e4c756d..c80dfdedb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,9 +108,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.90" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "ark-std" @@ -262,7 +262,6 @@ dependencies = [ "ceno_emul", "cfg-if", "clap", - "const_env", "criterion", "ff", "ff_ext", @@ -409,26 +408,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "const_env" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9e4f72c6e3398ca6da372abd9affd8f89781fe728869bbf986206e9af9627e" -dependencies = [ - "const_env_impl", -] - -[[package]] -name = "const_env_impl" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a4f51209740b5e1589e702b3044cdd4562cef41b6da404904192ffffb852d62" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -719,7 +698,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "crossbeam-channel", "ff", @@ -1630,18 +1608,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", @@ -1692,7 +1670,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "ff", "ff_ext", @@ -1800,7 +1777,6 @@ name = "sumcheck" version = "0.1.0" dependencies = [ "ark-std", - "const_env", "criterion", "crossbeam-channel", "ff", diff --git a/Cargo.toml b/Cargo.toml index fc601342b..aea908372 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,11 +26,10 @@ version = "0.1.0" [workspace.dependencies] ark-std = "0.4" cfg-if = "1.0" -const_env = "0.1" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" ff = "0.13" -goldilocks = { git = "https://github.com/zhenfeizhang/Goldilocks" } +goldilocks = { git = "https://github.com/hero78119/Goldilocks" } itertools = "0.13" paste = "1" plonky2 = "0.2" @@ -52,8 +51,5 @@ tracing = { version = "0.1", features = [ tracing-flame = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -[patch."https://github.com/zhenfeizhang/Goldilocks"] -goldilocks = { git = "https://github.com/hero78119/Goldilocks" } - [profile.release] lto = "thin" diff --git a/build.rs b/build.rs deleted file mode 100644 index 3e31cb0a9..000000000 --- a/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("cargo:rerun-if-env-changed=RAYON_NUM_THREADS"); -} diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a150918e3..2abd03cce 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -218,28 +218,6 @@ impl DecodedInstruction { } } - #[allow(dead_code)] - pub fn from_raw(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> Self { - // limit the range of inputs - let rs2 = rs2 & 0x1f; // 5bits mask - let rs1 = rs1 & 0x1f; - let rd = rd & 0x1f; - let func7 = kind.codes().func7; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - let insn = func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode; - Self { - insn, - top_bit: func7 | 0x80, - func7, - rs2, - rs1, - func3, - rd, - opcode, - } - } - pub fn encoded(&self) -> u32 { self.insn } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 36d2e77cb..ac16a6de9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -36,7 +36,6 @@ thread_local = "1.1" [dev-dependencies] base64 = "0.22" cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true serde_json.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9d69e120f..16d5cfe67 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -7,7 +7,6 @@ use ceno_zkvm::{ scheme::prover::ZKVMProver, structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; -use const_env::from_env; use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; @@ -37,31 +36,9 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_add(c: &mut Criterion) { type Pcs = BasefoldDefault; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; let mut zkvm_cs = ZKVMConstraintSystem::default(); let _ = zkvm_cs.register_opcode_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -128,7 +105,6 @@ fn bench_add(c: &mut Criterion) { commit, &[], num_instances, - max_threads, &mut transcript, &challenges, ) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index cc71355b4..3db8ba758 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -11,7 +11,6 @@ use ceno_zkvm::{ }, }; use clap::Parser; -use const_env::from_env; use ceno_emul::{ ByteAddr, CENO_PLATFORM, EmuContext, @@ -31,9 +30,6 @@ use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; use transcript::Transcript; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - const PROGRAM_SIZE: usize = 512; // For now, we assume registers // - x0 is not touched, @@ -82,27 +78,6 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; - let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -287,7 +262,7 @@ fn main() { let transcript = Transcript::new(b"riscv"); let mut zkvm_proof = prover - .create_proof(zkvm_witness, pi, max_threads, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); println!( diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index b0106c996..f39f6b301 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -173,7 +173,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.namespace(|| "require_equal", |cb| cb.cs.require_zero(name_fn, a - b)) + self.namespace( + || "require_equal", + |cb| { + cb.cs + .require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form()) + }, + ) } pub fn require_one(&mut self, name_fn: N, expr: Expression) -> Result<(), ZKVMError> diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 90f2bba97..4d8aec48c 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -639,22 +639,6 @@ impl WitIn { } } -#[macro_export] -/// this is to avoid non-monomial expression -macro_rules! create_witin_from_expr { - // Handle the case for a single expression - ($name:expr, $builder:expr, $debug:expr, $e:expr) => { - WitIn::from_expr($name, $builder, $e, $debug) - }; - // Recursively handle multiple expressions and create a flat tuple with error handling - ($name:expr, $builder:expr, $debug:expr, $e:expr, $($rest:expr),+) => { - { - // Return a Result tuple, handling errors - Ok::<_, ZKVMError>((WitIn::from_expr($name, $builder, $e, $debug)?, $(WitIn::from_expr($name, $builder, $rest)?),*)) - } - }; -} - pub trait ToExpr { type Output; fn expr(&self) -> Self::Output; @@ -757,7 +741,9 @@ pub mod fmt { ) -> String { match expression { Expression::WitIn(wit_in) => { - wtns.push(*wit_in); + if !wtns.contains(wit_in) { + wtns.push(*wit_in); + } format!("WitIn({})", wit_in) } Expression::Challenge(id, pow, scaler, offset) => { diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index e35ea6b7a..f8d40cdee 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -109,11 +109,23 @@ impl IsLtConfig { lhs: u64, rhs: u64, ) -> Result<(), ZKVMError> { - let is_lt = lhs < rhs; - set_val!(instance, self.is_lt, is_lt as u64); + set_val!(instance, self.is_lt, (lhs < rhs) as u64); self.config.assign_instance(instance, lkm, lhs, rhs)?; Ok(()) } + + pub fn assign_instance_signed( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: SWord, + rhs: SWord, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.is_lt, (lhs < rhs) as u64); + self.config + .assign_instance_signed(instance, lkm, lhs, rhs)?; + Ok(()) + } } #[derive(Debug, Clone)] @@ -337,12 +349,9 @@ impl InnerSignedLtConfig { 1, )?; - // Convert two's complement representation into field arithmetic. - // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 - let neg_shift = -Expression::Constant((1_u64 << 32).into()); - let lhs_value = lhs.value() + is_lhs_neg.expr() * neg_shift.clone(); - let rhs_value = rhs.value() + is_rhs_neg.expr() * neg_shift; - + // Convert to field arithmetic. + let lhs_value = lhs.to_field_expr(is_lhs_neg.expr()); + let rhs_value = rhs.to_field_expr(is_rhs_neg.expr()); let config = InnerLtConfig::construct_circuit( cb, format!("{name} (lhs < rhs)"), diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 5b88d8b92..96b192d60 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -17,6 +17,7 @@ pub mod mulh; pub mod shift; pub mod shift_imm; pub mod slt; +pub mod slti; pub mod sltu; mod b_insn; @@ -33,6 +34,8 @@ mod memory; mod s_insn; #[cfg(test)] mod test; +#[cfg(test)] +mod test_utils; pub trait RIVInstruction { const INST_KIND: InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 94f676f01..74e19acb2 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -88,21 +88,12 @@ mod test { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{Instruction, riscv::test_utils::imm_i}, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; use super::AddiInstruction; - fn imm(imm: i32) -> u32 { - // imm is 12 bits in B-type - const IMM_MAX: i32 = 2i32.pow(12); - if imm.is_negative() { - (IMM_MAX + imm) as u32 - } else { - imm as u32 - } - } #[test] fn test_opcode_addi() { let mut cs = ConstraintSystem::::new(|| "riscv"); @@ -118,7 +109,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm(3)); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(3)); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, @@ -162,7 +153,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm(-3)); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(-3)); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 6b2d6fc01..36746fff3 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -7,23 +7,13 @@ use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - instructions::Instruction, + instructions::{Instruction, riscv::test_utils::imm_b}, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; const A: Word = 0xbead1010; const B: Word = 0xef552020; -fn imm(imm: i32) -> u32 { - // imm is 13 bits in B-type - const IMM_MAX: i32 = 2i32.pow(13); - if imm.is_negative() { - (IMM_MAX + imm) as u32 - } else { - imm as u32 - } -} - #[test] fn test_opcode_beq() { impl_opcode_beq(false); @@ -44,7 +34,7 @@ fn impl_opcode_beq(equal: bool) { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, imm(8)); + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, imm_b(8)); let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; let (raw_witin, lkm) = BeqInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ @@ -93,7 +83,7 @@ fn impl_opcode_bne(equal: bool) { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, imm(8)); + let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, imm_b(8)); let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; let (raw_witin, lkm) = BneInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ @@ -145,7 +135,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, imm_b(-8)); println!("{:#b}", insn_code); let (raw_witin, lkm) = BltuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ @@ -198,7 +188,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, imm_b(-8)); let (raw_witin, lkm) = BgeuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -251,7 +241,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, imm_b(-8)); let (raw_witin, lkm) = BltInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -304,7 +294,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, imm_b(-8)); let (raw_witin, lkm) = BgeInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index f1152ce7c..a1b17e911 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -5,21 +5,15 @@ use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{ + Instruction, + riscv::test_utils::{imm_j, imm_u}, + }, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; use super::{AuipcInstruction, JalInstruction, JalrInstruction, LuiInstruction}; -fn imm_j(imm: i32) -> u32 { - // imm is 21 bits in J-type - const IMM_MAX: i32 = 2i32.pow(21); - if imm.is_negative() { - (IMM_MAX + imm) as u32 - } else { - imm as u32 - } -} #[test] fn test_opcode_jal() { let mut cs = ConstraintSystem::::new(|| "riscv"); @@ -113,10 +107,6 @@ fn test_opcode_jalr() { ); } -fn imm_u(imm: u32) -> u32 { - // valid imm is imm[12:31] in U-type - imm << 12 -} #[test] fn test_opcode_lui() { let mut cs = ConstraintSystem::::new(|| "riscv"); diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index b5cdda27f..166d91e21 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -3,35 +3,51 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - gadgets::DivConfig, + expression::{Expression, ToExpr, WitIn}, + gadgets::{AssertLTConfig, IsLtConfig}, instructions::{ Instruction, riscv::{constants::UInt, i_insn::IInstructionConfig}, }, + set_val, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::{marker::PhantomData, mem::MaybeUninit}; -pub struct InstructionConfig { +pub struct ShiftImmConfig { i_insn: IInstructionConfig, - imm: UInt, + imm: WitIn, + rs1_read: UInt, rd_written: UInt, - remainder: UInt, - div_config: DivConfig, + outflow: WitIn, + assert_lt_config: AssertLTConfig, + + // SRAI + is_lt_config: Option, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); +pub struct SlliOp; +impl RIVInstruction for SlliOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; +} + +pub struct SraiOp; +impl RIVInstruction for SraiOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI; +} + pub struct SrliOp; impl RIVInstruction for SrliOp { - const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRLI; + const INST_KIND: ceno_emul::InsnKind = InsnKind::SRLI; } impl Instruction for ShiftImmInstruction { - type InstructionConfig = InstructionConfig; + type InstructionConfig = ShiftImmConfig; fn name() -> String { format!("{:?}", I::INST_KIND) @@ -40,36 +56,77 @@ impl Instruction for ShiftImmInstructio fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result { - let mut imm = UInt::new(|| "imm", circuit_builder)?; - let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - // Note: `imm` is set to 2**imm (upto 32 bit) just for SRLI for efficient verification - // Goal is to constrain: - // rs1 == rd_written * imm + remainder - let remainder = UInt::new(|| "remainder", circuit_builder)?; - let div_config = DivConfig::construct_circuit( + // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. + let imm = circuit_builder.create_witin(|| "imm")?; + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow")?; + let assert_lt_config = AssertLTConfig::construct_circuit( circuit_builder, - || "srli_div", - &mut imm, - &mut rd_written, - &remainder, + || "outflow < imm", + outflow.expr(), + imm.expr(), + 2, )?; + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); + + let is_lt_config = match I::INST_KIND { + InsnKind::SLLI => { + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * imm.expr(), // inflow is zero for this case + outflow.expr() * two_pow_total_bits + rd_written.value(), + )?; + None + } + InsnKind::SRAI | InsnKind::SRLI => { + let (inflow, is_lt_config) = match I::INST_KIND { + InsnKind::SRAI => { + let max_signed_limb_expr: Expression<_> = + ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + circuit_builder, + || "lhs_msb", + max_signed_limb_expr.clone(), + rs1_read.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + let msb_expr: Expression = is_rs1_neg.is_lt.expr(); + let ones = imm.expr() - Expression::ONE; + (msb_expr * ones, Some(is_rs1_neg)) + } + InsnKind::SRLI => (Expression::ZERO, None), + _ => unreachable!(), + }; + circuit_builder.require_equal( + || "shift check", + rd_written.value() * imm.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), + )?; + is_lt_config + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.value(), - div_config.dividend.register_expr(), + &imm.expr(), + rs1_read.register_expr(), rd_written.register_expr(), false, )?; - Ok(InstructionConfig { + Ok(ShiftImmConfig { i_insn, imm, + rs1_read, rd_written, - remainder, - div_config, + outflow, + assert_lt_config, + is_lt_config, }) } @@ -79,26 +136,36 @@ impl Instruction for ShiftImmInstructio lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let imm = step.insn().imm_or_funct7(); + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - let (remainder, imm) = { - let rs1_read = step.rs1().unwrap().value; - let imm = step.insn().imm_or_funct7(); - ( - Value::new(rs1_read % imm, lk_multiplicity), - Value::new(imm, lk_multiplicity), - ) - }; - config.div_config.assign_instance( - instance, - lk_multiplicity, - &imm, - &rd_written, - &remainder, - )?; - config.imm.assign_value(instance, imm); + set_val!(instance, config.imm, imm as u64); + config.rs1_read.assign_value(instance, rs1_read.clone()); config.rd_written.assign_value(instance, rd_written); - config.remainder.assign_value(instance, remainder); + + let outflow = match I::INST_KIND { + InsnKind::SLLI => (rs1_read.as_u64() * imm as u64) >> UInt::::TOTAL_BITS, + InsnKind::SRAI | InsnKind::SRLI => { + if I::INST_KIND == InsnKind::SRAI { + let max_signed_limb_expr = (1 << (UInt::::LIMB_BITS - 1)) - 1; + config.is_lt_config.as_ref().unwrap().assign_instance( + instance, + lk_multiplicity, + max_signed_limb_expr, + rs1_read.as_u64() >> UInt::::LIMB_BITS, + )?; + } + + rs1_read.as_u64() & (imm as u64 - 1) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + set_val!(instance, config.outflow, outflow); + config + .assert_lt_config + .assign_instance(instance, lk_multiplicity, outflow, imm as u64)?; config .i_insn @@ -115,34 +182,91 @@ mod test { use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; + use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::constants::UInt}, + instructions::{ + Instruction, + riscv::{RIVInstruction, constants::UInt}, + }, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; - use super::{ShiftImmInstruction, SrliOp}; + #[test] + fn test_opcode_slli() { + // imm = 3 + verify::("32 << 3", 32, 3, 32 << 3); + verify::("33 << 3", 33, 3, 33 << 3); + // imm = 31 + verify::("32 << 31", 32, 31, 32 << 31); + verify::("33 << 31", 33, 31, 33 << 31); + } + + #[test] + fn test_opcode_srai() { + // positive rs1 + // imm = 3 + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); + // imm = 31 + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); + + // negative rs1 + // imm = 3 + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); + verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); + // imm = 31 + verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); + verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + } #[test] fn test_opcode_srli() { // imm = 3 - verify_srli(3, 32, 32 >> 3); - verify_srli(3, 33, 33 >> 3); + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); // imm = 31 - verify_srli(31, 32, 32 >> 31); - verify_srli(31, 33, 33 >> 31); + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); + // rs1 top bit is 1 + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32) as u32 >> 3); } - fn verify_srli(imm: u32, rs1_read: u32, expected_rd_written: u32) { + fn verify( + name: &'static str, + rs1_read: u32, + imm: u32, + expected_rd_written: u32, + ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + + let (prefix, insn_code, rd_written) = match I::INST_KIND { + InsnKind::SLLI => ( + "SLLI", + encode_rv32(InsnKind::SLLI, 2, 0, 4, imm), + rs1_read << imm, + ), + InsnKind::SRAI => ( + "SRAI", + encode_rv32(InsnKind::SRAI, 2, 0, 4, imm), + (rs1_read as i32 >> imm as i32) as u32, + ), + InsnKind::SRLI => ( + "SRLI", + encode_rv32(InsnKind::SRLI, 2, 0, 4, imm), + rs1_read >> imm, + ), + _ => unreachable!(), + }; + let config = cb .namespace( - || "srli", + || format!("{prefix}_({name})"), |cb| { - let config = - ShiftImmInstruction::::construct_circuit(cb); + let config = ShiftImmInstruction::::construct_circuit(cb); Ok(config) }, ) @@ -152,7 +276,7 @@ mod test { config .rd_written .require_equal( - || "assert_rd_written", + || format!("{prefix}_({name})_assert_rd_written"), &mut cb, &UInt::from_const_unchecked( Value::new_unchecked(expected_rd_written) @@ -162,8 +286,7 @@ mod test { ) .unwrap(); - let insn_code = encode_rv32(InsnKind::SRLI, 2, 0, 4, imm); - let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_i_instruction( @@ -171,22 +294,12 @@ mod test { Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, rs1_read, - Change::new(0, rs1_read >> imm), + Change::new(0, rd_written), 0, )], ) .unwrap(); - let expected_rd_written = UInt::from_const_unchecked( - Value::new_unchecked(expected_rd_written) - .as_u16_limbs() - .to_vec(), - ); - config - .rd_written - .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) - .unwrap(); - MockProver::assert_satisfied( &cb, &raw_witin diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs new file mode 100644 index 000000000..af31961df --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -0,0 +1,220 @@ +use std::marker::PhantomData; + +use ceno_emul::{InsnKind, SWord, StepRecord, Word}; +use ff_ext::ExtensionField; + +use super::{ + constants::{UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, +}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + gadgets::IsLtConfig, + instructions::Instruction, + set_val, + tables::InsnRecord, + uint::Value, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +#[derive(Debug)] +pub struct InstructionConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + #[allow(dead_code)] + rd_written: UInt, + + is_rs1_neg: IsLtConfig, + lt: IsLtConfig, +} + +pub struct SltiInstruction(PhantomData); + +impl Instruction for SltiInstruction { + type InstructionConfig = InstructionConfig; + + fn name() -> String { + format!("{:?}", InsnKind::SLTI) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let imm = cb.create_witin(|| "imm")?; + + let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + cb, + || "lhs_msb", + max_signed_limb_expr.clone(), + rs1_read.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + + let lt = IsLtConfig::construct_circuit( + cb, + || "rs1 < imm", + rs1_read.to_field_expr(is_rs1_neg.expr()), + imm.expr(), + UINT_LIMBS, + )?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + InsnKind::SLTI, + &imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(InstructionConfig { + i_insn, + rs1_read, + imm, + rd_written, + is_rs1_neg, + lt, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.i_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let max_signed_limb = (1u64 << (UInt::::LIMB_BITS - 1)) - 1; + let rs1_value = Value::new_unchecked(rs1 as Word); + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + config.is_rs1_neg.assign_instance( + instance, + lkm, + max_signed_limb, + *rs1_value.limbs.last().unwrap() as u64, + )?; + + let imm = step.insn().imm_or_funct7(); + let imm_field = InsnRecord::imm_or_funct7_field::(&step.insn()); + set_val!(instance, config.imm, imm_field); + + config + .lt + .assign_instance_signed(instance, lkm, rs1 as SWord, imm as SWord)?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; + use goldilocks::GoldilocksExt2; + + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + use rand::Rng; + + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::test_utils::imm_i}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + }; + + fn verify(name: &'static str, rs1: i32, imm: i32, rd: Word) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || format!("SLTI/{name}"), + |cb| { + let config = SltiInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm_i(imm)); + let (raw_witin, lkm) = + SltiInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + rs1 as Word, + Change::new(0, rd), + 0, + ), + ]) + .unwrap(); + + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + config + .rd_written + .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); + } + + #[test] + fn test_slti_true() { + verify("lt = true, 0 < 1", 0, 1, 1); + verify("lt = true, 1 < 2", 1, 2, 1); + verify("lt = true, -1 < 0", -1, 0, 1); + verify("lt = true, -1 < 1", -1, 1, 1); + verify("lt = true, -2 < -1", -2, -1, 1); + // -2048 <= imm <= 2047 + verify("lt = true, imm upper bondary", i32::MIN, 2047, 1); + verify("lt = true, imm lower bondary", i32::MIN, -2048, 1); + } + + #[test] + fn test_slti_false() { + verify("lt = false, 1 < 0", 1, 0, 0); + verify("lt = false, 2 < 1", 2, 1, 0); + verify("lt = false, 0 < -1", 0, -1, 0); + verify("lt = false, 1 < -1", 1, -1, 0); + verify("lt = false, -1 < -2", -1, -2, 0); + verify("lt = false, 0 == 0", 0, 0, 0); + verify("lt = false, 1 == 1", 1, 1, 0); + verify("lt = false, -1 == -1", -1, -1, 0); + // -2048 <= imm <= 2047 + verify("lt = false, imm upper bondary", i32::MAX, 2047, 0); + verify("lt = false, imm lower bondary", i32::MAX, -2048, 0); + } + + #[test] + fn test_slti_random() { + let mut rng = rand::thread_rng(); + let a: i32 = rng.gen(); + let b: i32 = rng.gen::() % 2048; + println!("random: {} u32 { + // imm is 13 bits in B-type + imm_with_max_valid_bits(imm, 13) +} + +pub fn imm_i(imm: i32) -> u32 { + // imm is 12 bits in I-type + imm_with_max_valid_bits(imm, 12) +} + +pub fn imm_j(imm: i32) -> u32 { + // imm is 21 bits in J-type + imm_with_max_valid_bits(imm, 21) +} + +fn imm_with_max_valid_bits(imm: i32, bits: u32) -> u32 { + imm as u32 & !(u32::MAX << bits) +} + +pub fn imm_u(imm: u32) -> u32 { + // valid imm is imm[12:31] in U-type + imm << 12 +} diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index a49fbbb00..77b9fe8ff 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -34,7 +34,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads}, + utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, }; @@ -56,7 +56,6 @@ impl> ZKVMProver { &self, witnesses: ZKVMWitnesses, pi: PublicValues, - max_threads: usize, mut transcript: Transcript, ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::empty(pi); @@ -150,7 +149,6 @@ impl> ZKVMProver { wits_commit, &pi, num_instances, - max_threads, transcript, &challenges, )?; @@ -170,7 +168,6 @@ impl> ZKVMProver { witness.into_iter().map(|v| v.into()).collect_vec(), wits_commit, &pi, - max_threads, transcript, &challenges, )?; @@ -204,7 +201,6 @@ impl> ZKVMProver { wits_commit: PCS::CommitmentWithData, pi: &[ArcMultilinearExtension<'_, E>], num_instances: usize, - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -338,7 +334,6 @@ impl> ZKVMProver { let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; assert!(record_r_out_evals.len() == NUM_FANIN && record_w_out_evals.len() == NUM_FANIN); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, vec![ TowerProverSpec { witness: r_wit_layers, @@ -381,7 +376,7 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); - let num_threads = proper_num_threads(log2_num_instances, max_threads); + let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -642,7 +637,6 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, pi: &[ArcMultilinearExtension<'_, E>], - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -863,7 +857,6 @@ impl> ZKVMProver { .collect_vec(); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, // pattern [r1, w1, r2, w2, ...] same pair are chain together r_wit_layers .into_iter() @@ -904,7 +897,7 @@ impl> ZKVMProver { // If all table length are the same, we can skip this sumcheck let span = entered_span!("sumcheck::opening_same_point"); // NOTE: max concurrency will be dominated by smallest table since it will blo - let num_threads = proper_num_threads(min_log2_num_instance, max_threads); + let num_threads = optimal_sumcheck_threads(min_log2_num_instance); let alpha_pow = get_challenge_pows( cs.r_table_expressions.len() + cs.w_table_expressions.len() @@ -1131,7 +1124,6 @@ impl TowerProofs { /// Tower Prover impl TowerProver { pub fn create_proof<'a, E: ExtensionField>( - max_threads: usize, prod_specs: Vec>, logup_specs: Vec>, num_fanin: usize, @@ -1166,7 +1158,7 @@ impl TowerProver { let (next_rt, _) = (1..=max_round_index).fold((initial_rt, alpha_pows), |(out_rt, alpha_pows), round| { // in first few round we just run on single thread - let num_threads = proper_num_threads(out_rt.len(), max_threads); + let num_threads = optimal_sumcheck_threads(out_rt.len()); let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 75cacf6ea..b64e34d99 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -141,7 +141,6 @@ fn test_rw_lk_expression_combination() { commit, &[], num_instances, - 1, &mut transcript, &prover_challenges, ) @@ -290,7 +289,7 @@ fn test_single_add_instance_e2e() { let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]); let transcript = Transcript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, pi, 1, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index fef0c80bc..f243e3769 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -527,6 +527,12 @@ impl UIntLimbs { UIntLimbs::from_exprs_unchecked(self_hi)?, )) } + + pub fn to_field_expr(&self, is_neg: Expression) -> Expression { + // Convert two's complement representation into field arithmetic. + // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 + self.value() - is_neg * (1_u64 << 32) + } } /// Construct `UIntLimbs` from `Vec` @@ -630,6 +636,7 @@ impl ValueMul { } } +#[derive(Clone)] pub struct Value<'a, T: Into + From + Copy + Default> { #[allow(dead_code)] val: T, diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 3ce3b65f6..62754aa0d 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -5,7 +5,6 @@ use itertools::{Itertools, izip}; use super::{UIntLimbs, UintLimb}; use crate::{ circuit_builder::CircuitBuilder, - create_witin_from_expr, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, gadgets::AssertLTConfig, @@ -281,7 +280,7 @@ impl UIntLimbs { .iter() .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); - let sum_flag = create_witin_from_expr!(|| "sum_flag", circuit_builder, false, sum_expr)?; + let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; let (is_equal, diff_inv) = circuit_builder.is_equal(sum_flag.expr(), Expression::from(n_limbs))?; Ok(IsEqualConfig { @@ -314,7 +313,7 @@ impl UIntLimbs { let inv_128 = F::from(128).invert().unwrap(); let msb = (high_limb - high_limb_no_msb.expr()) * Expression::Constant(inv_128); - let msb = create_witin_from_expr!(|| "msb", circuit_builder, false, msb)?; + let msb = WitIn::from_expr(|| "msb", circuit_builder, msb, false)?; Ok(MsbConfig { msb, high_limb_no_msb, @@ -359,7 +358,7 @@ impl UIntLimbs { .rev() .enumerate() .map(|(i, expr)| { - create_witin_from_expr!(|| format!("si_expr_{i}"), circuit_builder, false, expr) + WitIn::from_expr(|| format!("si_expr_{i}"), circuit_builder, expr, false) }) .collect::, ZKVMError>>()?; @@ -394,10 +393,8 @@ impl UIntLimbs { // check the first byte difference has a inverse // unwrap is safe because vector len > 0 - let lhs_ne_byte = - create_witin_from_expr!(|| "lhs_ne_byte", circuit_builder, false, sa.clone())?; - let rhs_ne_byte = - create_witin_from_expr!(|| "rhs_ne_byte", circuit_builder, false, sb.clone())?; + let lhs_ne_byte = WitIn::from_expr(|| "lhs_ne_byte", circuit_builder, sa.clone(), false)?; + let rhs_ne_byte = WitIn::from_expr(|| "rhs_ne_byte", circuit_builder, sb.clone(), false)?; let index_ne = si.first().unwrap(); circuit_builder.require_zero( || "byte inverse check", diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index a6cee56e9..e8a5553d6 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -2,6 +2,7 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; use transcript::Transcript; /// convert ext field element to u64, assume it is inside the range @@ -113,7 +114,8 @@ pub fn u64vec(x: u64) -> [u64; W] { /// we expect each thread at least take 4 num of sumcheck variables /// return optimal num threads to run sumcheck -pub fn proper_num_threads(num_vars: usize, expected_max_threads: usize) -> usize { +pub fn optimal_sumcheck_threads(num_vars: usize) -> usize { + let expected_max_threads = max_usable_threads(); let min_numvar_per_thread = 4; if num_vars <= min_numvar_per_thread { 1 diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 702791bd0..fab3ec6b8 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -9,7 +9,6 @@ ark-std.workspace = true ff.workspace = true goldilocks.workspace = true -const_env.workspace = true crossbeam-channel.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index fce8ffb31..d48920dd7 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -3,11 +3,11 @@ use std::time::Duration; -use const_env::from_env; use criterion::*; use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256}; use goldilocks::GoldilocksExt2; +use multilinear_extensions::util::max_usable_threads; cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { @@ -28,8 +28,6 @@ cfg_if::cfg_if! { criterion_main!(keccak256); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_keccak256(c: &mut Criterion) { println!( @@ -37,26 +35,7 @@ fn bench_keccak256(c: &mut Criterion) { keccak256_circuit::().layers.len() ); - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let circuit = keccak256_circuit::(); diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index 28e4f8284..a0a8e56a2 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -30,3 +30,21 @@ pub fn create_uninit_vec(len: usize) -> Vec> { pub fn largest_even_below(n: usize) -> usize { if n % 2 == 0 { n } else { n.saturating_sub(1) } } + +fn prev_power_of_two(n: usize) -> usize { + (n + 1).next_power_of_two() / 2 +} + +/// Largest power of two that fits the available rayon threads +pub fn max_usable_threads() -> usize { + if cfg!(test) { + 1 + } else { + let n = rayon::current_num_threads(); + let threads = prev_power_of_two(n); + if n != threads { + tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead."); + } + threads + } +} diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index a5468ad90..45e4e9fb7 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -2,7 +2,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - util::{bit_decompose, create_uninit_vec}, + util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; use ff::{Field, PrimeField}; @@ -452,8 +452,7 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { // .... // 1 1 1 1 -> r0 * r1 * r2 * r3 // we will need 2^num_var evaluations - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let nbits = nthreads.trailing_zeros() as usize; assert_eq!(1 << nbits, nthreads); diff --git a/singer/Cargo.toml b/singer/Cargo.toml index e6c10602e..ac71ab4b9 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -28,7 +28,6 @@ tracing-subscriber.workspace = true [dev-dependencies] cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true tracing.workspace = true diff --git a/singer/benches/add.rs b/singer/benches/add.rs index 70fee6f28..5984a19b2 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -4,7 +4,6 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::{ExtensionField, ff::Field}; @@ -30,9 +29,8 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; +use multilinear_extensions::util::max_usable_threads; use singer::{ CircuitWiresIn, SingerGraphBuilder, SingerParams, instructions::{Instruction, InstructionGraph, SingerCircuitBuilder, add::AddInstruction}, @@ -42,26 +40,7 @@ use singer_utils::structs::ChipChallenges; use transcript::Transcript; fn bench_add(c: &mut Criterion) { - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let chip_challenges = ChipChallenges::default(); let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 4b9605cdf..540495c0a 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -6,7 +6,6 @@ version.workspace = true [dependencies] ark-std.workspace = true -const_env.workspace = true ff.workspace = true ff_ext = { path = "../ff_ext" } goldilocks.workspace = true diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 7116789b9..fd33b9d09 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -4,7 +4,6 @@ use std::array; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::ExtensionField; use itertools::Itertools; @@ -14,6 +13,7 @@ use goldilocks::GoldilocksExt2; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, + util::max_usable_threads, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; use transcript::Transcript; @@ -41,10 +41,10 @@ pub fn transpose(v: Vec>) -> Vec> { } fn prepare_input<'a, E: ExtensionField>( - max_thread_id: usize, nv: usize, ) -> (E, VirtualPolynomial<'a, E>, Vec>) { let mut rng = test_rng(); + let max_thread_id = max_usable_threads(); let size_log2 = ceil_log2(max_thread_id); let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| { let mle: ArcMultilinearExtension<'a, E> = @@ -100,9 +100,6 @@ fn prepare_input<'a, E: ExtensionField>( (asserted_sum, virtual_poly_v1, virtual_poly_v2) } -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; @@ -119,7 +116,7 @@ fn sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -150,6 +147,7 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; + let threads = max_usable_threads(); for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); @@ -163,7 +161,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -178,7 +176,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { virtual_poly_splitted, )| { let (_sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, + threads, virtual_poly_splitted, &mut prover_transcript, ); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs deleted file mode 100644 index 29ff81368..000000000 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::Arc; - -use ark_std::test_rng; -use const_env::from_env; -use ff_ext::{ExtensionField, ff::Field}; -use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, -}; -use sumcheck::{ - structs::{IOPProverState, IOPVerifierState}, - util::ceil_log2, -}; -use transcript::Transcript; - -type E = GoldilocksExt2; - -fn prepare_input( - max_thread_id: usize, -) -> (E, VirtualPolynomial, Vec>) { - let nv = 10; - let mut rng = test_rng(); - let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); - - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - (asserted_sum, virtual_poly_1, virtual_poly_f1) -} - -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - -fn main() { - let mut prover_transcript_v1 = Transcript::::new(b"test"); - let mut prover_transcript_v2 = Transcript::::new(b"test"); - - let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); - let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, - virtual_poly_splitted.clone(), - &mut prover_transcript_v2, - ); - println!("v2 finish"); - - let mut transcript = Transcript::new(b"test"); - let poly_info = virtual_poly.aux_info.clone(); - let subclaim = IOPVerifierState::::verify( - asserted_sum, - &sumcheck_proof_v2, - &poly_info, - &mut transcript, - ); - assert!( - virtual_poly.evaluate( - subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, - "wrong subclaim" - ); - - #[allow(deprecated)] - let (sumcheck_proof_v1, _) = - IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); - - println!("v1 finish"); - assert!(sumcheck_proof_v2 == sumcheck_proof_v1); -} diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 2336e52a9..f2ede8680 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -43,6 +43,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> (IOPProof, IOPProverStateV2<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); + assert!(max_thread_id.is_power_of_two()); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 assert!(