From 7a965566f468f22131ca76367ddf4166c3515f9f Mon Sep 17 00:00:00 2001 From: soham Date: Fri, 25 Oct 2024 20:42:54 +0530 Subject: [PATCH 01/10] `SLLI` opcode (#434) Closes #364 --- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 172 +++++++++++++----- 1 file changed, 122 insertions(+), 50 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index b5cdda27f..1135b5cc0 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -10,28 +10,36 @@ use crate::{ }, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::{marker::PhantomData, mem::MaybeUninit}; -pub struct InstructionConfig { +pub struct ShiftImmConfig { i_insn: IInstructionConfig, + rs1_read: UInt, imm: UInt, rd_written: UInt, - remainder: UInt, - div_config: DivConfig, + + // for SRLI division arithmetics + remainder: Option>, + div_config: Option>, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); +pub struct SlliOp; +impl RIVInstruction for SlliOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; +} + pub struct SrliOp; impl RIVInstruction for SrliOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRLI; } impl Instruction for ShiftImmInstruction { - type InstructionConfig = InstructionConfig; + type InstructionConfig = ShiftImmConfig; fn name() -> String { format!("{:?}", I::INST_KIND) @@ -41,33 +49,56 @@ impl Instruction for ShiftImmInstructio circuit_builder: &mut CircuitBuilder, ) -> Result { let mut imm = UInt::new(|| "imm", circuit_builder)?; - let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?; - // Note: `imm` is set to 2**imm (upto 32 bit) just for SRLI for efficient verification + // Note: `imm` is set to 2**imm (upto 32 bit) just for efficient verification // Goal is to constrain: // rs1 == rd_written * imm + remainder - let remainder = UInt::new(|| "remainder", circuit_builder)?; - let div_config = DivConfig::construct_circuit( - circuit_builder, - || "srli_div", - &mut imm, - &mut rd_written, - &remainder, - )?; + let (rs1_read, rd_written, remainder, div_config) = match I::INST_KIND { + InsnKind::SLLI => { + let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = rs1_read.mul( + || "rd_written = rs1_read * imm", + circuit_builder, + &mut imm, + true, + )?; + + (rs1_read, rd_written, None, None) + } + InsnKind::SRLI => { + let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?; + let remainder = UInt::new(|| "remainder", circuit_builder)?; + let div_config = DivConfig::construct_circuit( + circuit_builder, + || "srli_div", + &mut imm, + &mut rd_written, + &remainder, + )?; + ( + div_config.dividend.clone(), + rd_written, + Some(remainder), + Some(div_config), + ) + } + _ => unreachable!(), + }; let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, &imm.value(), - div_config.dividend.register_expr(), + rs1_read.register_expr(), rd_written.register_expr(), false, )?; - Ok(InstructionConfig { + Ok(ShiftImmConfig { i_insn, imm, rd_written, + rs1_read, remainder, div_config, }) @@ -79,26 +110,38 @@ impl Instruction for ShiftImmInstructio lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - - let (remainder, imm) = { - let rs1_read = step.rs1().unwrap().value; - let imm = step.insn().imm_or_funct7(); - ( - Value::new(rs1_read % imm, lk_multiplicity), - Value::new(imm, lk_multiplicity), - ) + let imm = Value::new(step.insn().imm_or_funct7(), lk_multiplicity); + match I::INST_KIND { + InsnKind::SLLI => { + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + let rd_written = rs1_read.mul(&imm, lk_multiplicity, true); + config.rs1_read.assign_value(instance, rs1_read); + config + .rd_written + .assign_mul_outcome(instance, lk_multiplicity, &rd_written)?; + } + InsnKind::SRLI => { + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + let rs1_read = step.rs1().unwrap().value; + let remainder = Value::new(rs1_read % imm.as_u32(), lk_multiplicity); + config.div_config.as_ref().unwrap().assign_instance( + instance, + lk_multiplicity, + &imm, + &rd_written, + &remainder, + )?; + config + .remainder + .as_ref() + .unwrap() + .assign_value(instance, remainder); + config.rd_written.assign_value(instance, rd_written); + } + _ => unreachable!(), }; - config.div_config.assign_instance( - instance, - lk_multiplicity, - &imm, - &rd_written, - &remainder, - )?; + config.imm.assign_value(instance, imm); - config.rd_written.assign_value(instance, rd_written); - config.remainder.assign_value(instance, remainder); config .i_insn @@ -118,31 +161,61 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::constants::UInt}, + instructions::{ + Instruction, + riscv::{RIVInstruction, constants::UInt}, + }, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; - use super::{ShiftImmInstruction, SrliOp}; + use super::{ShiftImmInstruction, SlliOp, SrliOp}; + + #[test] + fn test_opcode_slli() { + verify::("imm = 3, rs1 = 32", 3, 32, 32 << 3); + verify::("imm = 3, rs1 = 33", 3, 33, 33 << 3); + + verify::("imm = 31, rs1 = 32", 31, 32, 32 << 31); + verify::("imm = 31, rs1 = 33", 31, 33, 33 << 31); + } #[test] fn test_opcode_srli() { - // imm = 3 - verify_srli(3, 32, 32 >> 3); - verify_srli(3, 33, 33 >> 3); - // imm = 31 - verify_srli(31, 32, 32 >> 31); - verify_srli(31, 33, 33 >> 31); + verify::("imm = 3, rs1 = 32", 3, 32, 32 >> 3); + verify::("imm = 3, rs1 = 33", 3, 33, 33 >> 3); + + verify::("imm = 31, rs1 = 32", 31, 32, 32 >> 31); + verify::("imm = 31, rs1 = 33", 31, 33, 33 >> 31); } - fn verify_srli(imm: u32, rs1_read: u32, expected_rd_written: u32) { + fn verify( + name: &'static str, + imm: u32, + rs1_read: u32, + expected_rd_written: u32, + ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + + let (prefix, insn_code, rd_written) = match I::INST_KIND { + InsnKind::SLLI => ( + "SLLI", + encode_rv32(InsnKind::SLLI, 2, 0, 4, imm), + rs1_read << imm, + ), + InsnKind::SRLI => ( + "SRLI", + encode_rv32(InsnKind::SRLI, 2, 0, 4, imm), + rs1_read >> imm, + ), + _ => unreachable!(), + }; + let config = cb .namespace( - || "srli", + || format!("{prefix}_({name})"), |cb| { - let config = - ShiftImmInstruction::::construct_circuit(cb); + let config = ShiftImmInstruction::::construct_circuit(cb); Ok(config) }, ) @@ -162,8 +235,7 @@ mod test { ) .unwrap(); - let insn_code = encode_rv32(InsnKind::SRLI, 2, 0, 4, imm); - let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_i_instruction( @@ -171,7 +243,7 @@ mod test { Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, rs1_read, - Change::new(0, rs1_read >> imm), + Change::new(0, rd_written), 0, )], ) From 79bb73bb89c7a1f503b0838711acaa4d7e7679f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 28 Oct 2024 10:20:03 +0800 Subject: [PATCH 02/10] Ignore `.DS_Store` (#472) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 3f7f23880..b432b7fc9 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target log.txt logs/ table_cache_dev_* +.DS_Store From 7471ae5f5fc680cf76850ca2516e21baf1c96251 Mon Sep 17 00:00:00 2001 From: Kimi Wu Date: Mon, 28 Oct 2024 11:42:05 +0800 Subject: [PATCH 03/10] Feat/#367 SLTI (#453) To close #367 Co-authored-by: sm.wu --- ceno_emul/src/rv32im.rs | 22 -- ceno_zkvm/src/gadgets/is_lt.rs | 25 +- ceno_zkvm/src/instructions/riscv.rs | 3 + ceno_zkvm/src/instructions/riscv/arith_imm.rs | 15 +- .../src/instructions/riscv/branch/test.rs | 24 +- ceno_zkvm/src/instructions/riscv/jump/test.rs | 18 +- ceno_zkvm/src/instructions/riscv/slti.rs | 220 ++++++++++++++++++ .../src/instructions/riscv/test_utils.rs | 23 ++ ceno_zkvm/src/uint.rs | 6 + 9 files changed, 283 insertions(+), 73 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/slti.rs create mode 100644 ceno_zkvm/src/instructions/riscv/test_utils.rs diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a150918e3..2abd03cce 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -218,28 +218,6 @@ impl DecodedInstruction { } } - #[allow(dead_code)] - pub fn from_raw(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> Self { - // limit the range of inputs - let rs2 = rs2 & 0x1f; // 5bits mask - let rs1 = rs1 & 0x1f; - let rd = rd & 0x1f; - let func7 = kind.codes().func7; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - let insn = func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode; - Self { - insn, - top_bit: func7 | 0x80, - func7, - rs2, - rs1, - func3, - rd, - opcode, - } - } - pub fn encoded(&self) -> u32 { self.insn } diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index e35ea6b7a..f8d40cdee 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -109,11 +109,23 @@ impl IsLtConfig { lhs: u64, rhs: u64, ) -> Result<(), ZKVMError> { - let is_lt = lhs < rhs; - set_val!(instance, self.is_lt, is_lt as u64); + set_val!(instance, self.is_lt, (lhs < rhs) as u64); self.config.assign_instance(instance, lkm, lhs, rhs)?; Ok(()) } + + pub fn assign_instance_signed( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: SWord, + rhs: SWord, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.is_lt, (lhs < rhs) as u64); + self.config + .assign_instance_signed(instance, lkm, lhs, rhs)?; + Ok(()) + } } #[derive(Debug, Clone)] @@ -337,12 +349,9 @@ impl InnerSignedLtConfig { 1, )?; - // Convert two's complement representation into field arithmetic. - // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 - let neg_shift = -Expression::Constant((1_u64 << 32).into()); - let lhs_value = lhs.value() + is_lhs_neg.expr() * neg_shift.clone(); - let rhs_value = rhs.value() + is_rhs_neg.expr() * neg_shift; - + // Convert to field arithmetic. + let lhs_value = lhs.to_field_expr(is_lhs_neg.expr()); + let rhs_value = rhs.to_field_expr(is_rhs_neg.expr()); let config = InnerLtConfig::construct_circuit( cb, format!("{name} (lhs < rhs)"), diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 5b88d8b92..96b192d60 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -17,6 +17,7 @@ pub mod mulh; pub mod shift; pub mod shift_imm; pub mod slt; +pub mod slti; pub mod sltu; mod b_insn; @@ -33,6 +34,8 @@ mod memory; mod s_insn; #[cfg(test)] mod test; +#[cfg(test)] +mod test_utils; pub trait RIVInstruction { const INST_KIND: InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 94f676f01..74e19acb2 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -88,21 +88,12 @@ mod test { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{Instruction, riscv::test_utils::imm_i}, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; use super::AddiInstruction; - fn imm(imm: i32) -> u32 { - // imm is 12 bits in B-type - const IMM_MAX: i32 = 2i32.pow(12); - if imm.is_negative() { - (IMM_MAX + imm) as u32 - } else { - imm as u32 - } - } #[test] fn test_opcode_addi() { let mut cs = ConstraintSystem::::new(|| "riscv"); @@ -118,7 +109,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm(3)); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(3)); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, @@ -162,7 +153,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm(-3)); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(-3)); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 6b2d6fc01..36746fff3 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -7,23 +7,13 @@ use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - instructions::Instruction, + instructions::{Instruction, riscv::test_utils::imm_b}, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; const A: Word = 0xbead1010; const B: Word = 0xef552020; -fn imm(imm: i32) -> u32 { - // imm is 13 bits in B-type - const IMM_MAX: i32 = 2i32.pow(13); - if imm.is_negative() { - (IMM_MAX + imm) as u32 - } else { - imm as u32 - } -} - #[test] fn test_opcode_beq() { impl_opcode_beq(false); @@ -44,7 +34,7 @@ fn impl_opcode_beq(equal: bool) { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, imm(8)); + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, imm_b(8)); let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; let (raw_witin, lkm) = BeqInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ @@ -93,7 +83,7 @@ fn impl_opcode_bne(equal: bool) { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, imm(8)); + let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, imm_b(8)); let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; let (raw_witin, lkm) = BneInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ @@ -145,7 +135,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, imm_b(-8)); println!("{:#b}", insn_code); let (raw_witin, lkm) = BltuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ @@ -198,7 +188,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, imm_b(-8)); let (raw_witin, lkm) = BgeuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -251,7 +241,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, imm_b(-8)); let (raw_witin, lkm) = BltInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -304,7 +294,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, imm(-8)); + let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, imm_b(-8)); let (raw_witin, lkm) = BgeInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index f1152ce7c..a1b17e911 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -5,21 +5,15 @@ use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, + instructions::{ + Instruction, + riscv::test_utils::{imm_j, imm_u}, + }, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; use super::{AuipcInstruction, JalInstruction, JalrInstruction, LuiInstruction}; -fn imm_j(imm: i32) -> u32 { - // imm is 21 bits in J-type - const IMM_MAX: i32 = 2i32.pow(21); - if imm.is_negative() { - (IMM_MAX + imm) as u32 - } else { - imm as u32 - } -} #[test] fn test_opcode_jal() { let mut cs = ConstraintSystem::::new(|| "riscv"); @@ -113,10 +107,6 @@ fn test_opcode_jalr() { ); } -fn imm_u(imm: u32) -> u32 { - // valid imm is imm[12:31] in U-type - imm << 12 -} #[test] fn test_opcode_lui() { let mut cs = ConstraintSystem::::new(|| "riscv"); diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs new file mode 100644 index 000000000..af31961df --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -0,0 +1,220 @@ +use std::marker::PhantomData; + +use ceno_emul::{InsnKind, SWord, StepRecord, Word}; +use ff_ext::ExtensionField; + +use super::{ + constants::{UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, +}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + gadgets::IsLtConfig, + instructions::Instruction, + set_val, + tables::InsnRecord, + uint::Value, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +#[derive(Debug)] +pub struct InstructionConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + #[allow(dead_code)] + rd_written: UInt, + + is_rs1_neg: IsLtConfig, + lt: IsLtConfig, +} + +pub struct SltiInstruction(PhantomData); + +impl Instruction for SltiInstruction { + type InstructionConfig = InstructionConfig; + + fn name() -> String { + format!("{:?}", InsnKind::SLTI) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let imm = cb.create_witin(|| "imm")?; + + let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + cb, + || "lhs_msb", + max_signed_limb_expr.clone(), + rs1_read.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + + let lt = IsLtConfig::construct_circuit( + cb, + || "rs1 < imm", + rs1_read.to_field_expr(is_rs1_neg.expr()), + imm.expr(), + UINT_LIMBS, + )?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()])?; + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + InsnKind::SLTI, + &imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(InstructionConfig { + i_insn, + rs1_read, + imm, + rd_written, + is_rs1_neg, + lt, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.i_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let max_signed_limb = (1u64 << (UInt::::LIMB_BITS - 1)) - 1; + let rs1_value = Value::new_unchecked(rs1 as Word); + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + config.is_rs1_neg.assign_instance( + instance, + lkm, + max_signed_limb, + *rs1_value.limbs.last().unwrap() as u64, + )?; + + let imm = step.insn().imm_or_funct7(); + let imm_field = InsnRecord::imm_or_funct7_field::(&step.insn()); + set_val!(instance, config.imm, imm_field); + + config + .lt + .assign_instance_signed(instance, lkm, rs1 as SWord, imm as SWord)?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; + use goldilocks::GoldilocksExt2; + + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + use rand::Rng; + + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::test_utils::imm_i}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + }; + + fn verify(name: &'static str, rs1: i32, imm: i32, rd: Word) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || format!("SLTI/{name}"), + |cb| { + let config = SltiInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm_i(imm)); + let (raw_witin, lkm) = + SltiInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + rs1 as Word, + Change::new(0, rd), + 0, + ), + ]) + .unwrap(); + + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + config + .rd_written + .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); + } + + #[test] + fn test_slti_true() { + verify("lt = true, 0 < 1", 0, 1, 1); + verify("lt = true, 1 < 2", 1, 2, 1); + verify("lt = true, -1 < 0", -1, 0, 1); + verify("lt = true, -1 < 1", -1, 1, 1); + verify("lt = true, -2 < -1", -2, -1, 1); + // -2048 <= imm <= 2047 + verify("lt = true, imm upper bondary", i32::MIN, 2047, 1); + verify("lt = true, imm lower bondary", i32::MIN, -2048, 1); + } + + #[test] + fn test_slti_false() { + verify("lt = false, 1 < 0", 1, 0, 0); + verify("lt = false, 2 < 1", 2, 1, 0); + verify("lt = false, 0 < -1", 0, -1, 0); + verify("lt = false, 1 < -1", 1, -1, 0); + verify("lt = false, -1 < -2", -1, -2, 0); + verify("lt = false, 0 == 0", 0, 0, 0); + verify("lt = false, 1 == 1", 1, 1, 0); + verify("lt = false, -1 == -1", -1, -1, 0); + // -2048 <= imm <= 2047 + verify("lt = false, imm upper bondary", i32::MAX, 2047, 0); + verify("lt = false, imm lower bondary", i32::MAX, -2048, 0); + } + + #[test] + fn test_slti_random() { + let mut rng = rand::thread_rng(); + let a: i32 = rng.gen(); + let b: i32 = rng.gen::() % 2048; + println!("random: {} u32 { + // imm is 13 bits in B-type + imm_with_max_valid_bits(imm, 13) +} + +pub fn imm_i(imm: i32) -> u32 { + // imm is 12 bits in I-type + imm_with_max_valid_bits(imm, 12) +} + +pub fn imm_j(imm: i32) -> u32 { + // imm is 21 bits in J-type + imm_with_max_valid_bits(imm, 21) +} + +fn imm_with_max_valid_bits(imm: i32, bits: u32) -> u32 { + imm as u32 & !(u32::MAX << bits) +} + +pub fn imm_u(imm: u32) -> u32 { + // valid imm is imm[12:31] in U-type + imm << 12 +} diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index fef0c80bc..1470aa14d 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -527,6 +527,12 @@ impl UIntLimbs { UIntLimbs::from_exprs_unchecked(self_hi)?, )) } + + pub fn to_field_expr(&self, is_neg: Expression) -> Expression { + // Convert two's complement representation into field arithmetic. + // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 + self.value() - is_neg * (1_u64 << 32) + } } /// Construct `UIntLimbs` from `Vec` From 0c56fc97513ba69a2197f7201c5d324e8ae45218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 28 Oct 2024 15:22:34 +0800 Subject: [PATCH 04/10] Remove useless macro (#467) One case of the macro was totally unused, the other did not actually provide any benefit over just typing out the code. --- ceno_zkvm/src/expression.rs | 16 ---------------- ceno_zkvm/src/uint/arithmetic.rs | 13 +++++-------- 2 files changed, 5 insertions(+), 24 deletions(-) diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 90f2bba97..0cb39bfe3 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -639,22 +639,6 @@ impl WitIn { } } -#[macro_export] -/// this is to avoid non-monomial expression -macro_rules! create_witin_from_expr { - // Handle the case for a single expression - ($name:expr, $builder:expr, $debug:expr, $e:expr) => { - WitIn::from_expr($name, $builder, $e, $debug) - }; - // Recursively handle multiple expressions and create a flat tuple with error handling - ($name:expr, $builder:expr, $debug:expr, $e:expr, $($rest:expr),+) => { - { - // Return a Result tuple, handling errors - Ok::<_, ZKVMError>((WitIn::from_expr($name, $builder, $e, $debug)?, $(WitIn::from_expr($name, $builder, $rest)?),*)) - } - }; -} - pub trait ToExpr { type Output; fn expr(&self) -> Self::Output; diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 3ce3b65f6..62754aa0d 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -5,7 +5,6 @@ use itertools::{Itertools, izip}; use super::{UIntLimbs, UintLimb}; use crate::{ circuit_builder::CircuitBuilder, - create_witin_from_expr, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, gadgets::AssertLTConfig, @@ -281,7 +280,7 @@ impl UIntLimbs { .iter() .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); - let sum_flag = create_witin_from_expr!(|| "sum_flag", circuit_builder, false, sum_expr)?; + let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; let (is_equal, diff_inv) = circuit_builder.is_equal(sum_flag.expr(), Expression::from(n_limbs))?; Ok(IsEqualConfig { @@ -314,7 +313,7 @@ impl UIntLimbs { let inv_128 = F::from(128).invert().unwrap(); let msb = (high_limb - high_limb_no_msb.expr()) * Expression::Constant(inv_128); - let msb = create_witin_from_expr!(|| "msb", circuit_builder, false, msb)?; + let msb = WitIn::from_expr(|| "msb", circuit_builder, msb, false)?; Ok(MsbConfig { msb, high_limb_no_msb, @@ -359,7 +358,7 @@ impl UIntLimbs { .rev() .enumerate() .map(|(i, expr)| { - create_witin_from_expr!(|| format!("si_expr_{i}"), circuit_builder, false, expr) + WitIn::from_expr(|| format!("si_expr_{i}"), circuit_builder, expr, false) }) .collect::, ZKVMError>>()?; @@ -394,10 +393,8 @@ impl UIntLimbs { // check the first byte difference has a inverse // unwrap is safe because vector len > 0 - let lhs_ne_byte = - create_witin_from_expr!(|| "lhs_ne_byte", circuit_builder, false, sa.clone())?; - let rhs_ne_byte = - create_witin_from_expr!(|| "rhs_ne_byte", circuit_builder, false, sb.clone())?; + let lhs_ne_byte = WitIn::from_expr(|| "lhs_ne_byte", circuit_builder, sa.clone(), false)?; + let rhs_ne_byte = WitIn::from_expr(|| "rhs_ne_byte", circuit_builder, sb.clone(), false)?; let index_ne = si.first().unwrap(); circuit_builder.require_zero( || "byte inverse check", From 17ae298ec5120e37238a1c3fe7a1f5a8229fda89 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 07:35:43 +0000 Subject: [PATCH 05/10] Bump serde from 1.0.210 to 1.0.213 (#477) --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 24e4c756d..8a826994f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1630,18 +1630,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", From e3ce193cb8ca03acf3f151b78e18200066b21438 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 07:35:59 +0000 Subject: [PATCH 06/10] Bump anyhow from 1.0.90 to 1.0.91 (#478) --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8a826994f..bf37a3132 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,9 +108,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.90" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "ark-std" From 1b8d622e077dfc0aeb4cc2700f9c58ef6ace0d40 Mon Sep 17 00:00:00 2001 From: Ming Date: Mon, 28 Oct 2024 16:54:47 +0800 Subject: [PATCH 07/10] Simplify thread pool configuration (#464) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR should resolve occasionally hang issue during various condition. Root cause is (probably) from `RAYON_NUM_THREADS` env parse from `cons_env` crate, which properly a race condition so the environment var change doesn't trigger a rebuild effectively. This will cause potiential mismatch with system rayon thread (for unknown reason), probably a bug in `cons_env`. Despite the root cause is not 100% for sure, we can clean up those complexity, and instead just respect rayon thread pool parsing from global entry, which greatly simplify the overall flow. Change has been verified on remote benchmark machine. before/after didn't cause performance difference --------- Co-authored-by: Matthias Görgens --- Cargo.lock | 24 ----- Cargo.toml | 1 - build.rs | 3 - ceno_zkvm/Cargo.toml | 1 - ceno_zkvm/benches/riscv_add.rs | 24 ----- ceno_zkvm/examples/riscv_opcodes.rs | 27 +---- ceno_zkvm/src/scheme/prover.rs | 16 +-- ceno_zkvm/src/scheme/tests.rs | 3 +- ceno_zkvm/src/utils.rs | 4 +- gkr/Cargo.toml | 1 - gkr/benches/keccak256.rs | 25 +---- multilinear_extensions/src/util.rs | 18 ++++ multilinear_extensions/src/virtual_poly.rs | 5 +- singer/Cargo.toml | 1 - singer/benches/add.rs | 25 +---- sumcheck/Cargo.toml | 1 - sumcheck/benches/devirgo_sumcheck.rs | 14 ++- sumcheck/examples/devirgo_sumcheck.rs | 112 --------------------- sumcheck/src/prover_v2.rs | 1 + 19 files changed, 40 insertions(+), 266 deletions(-) delete mode 100644 build.rs delete mode 100644 sumcheck/examples/devirgo_sumcheck.rs diff --git a/Cargo.lock b/Cargo.lock index bf37a3132..c80dfdedb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,7 +262,6 @@ dependencies = [ "ceno_emul", "cfg-if", "clap", - "const_env", "criterion", "ff", "ff_ext", @@ -409,26 +408,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "const_env" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9e4f72c6e3398ca6da372abd9affd8f89781fe728869bbf986206e9af9627e" -dependencies = [ - "const_env_impl", -] - -[[package]] -name = "const_env_impl" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a4f51209740b5e1589e702b3044cdd4562cef41b6da404904192ffffb852d62" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -719,7 +698,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "crossbeam-channel", "ff", @@ -1692,7 +1670,6 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "ff", "ff_ext", @@ -1800,7 +1777,6 @@ name = "sumcheck" version = "0.1.0" dependencies = [ "ark-std", - "const_env", "criterion", "crossbeam-channel", "ff", diff --git a/Cargo.toml b/Cargo.toml index fc601342b..ea6d35b27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ version = "0.1.0" [workspace.dependencies] ark-std = "0.4" cfg-if = "1.0" -const_env = "0.1" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" ff = "0.13" diff --git a/build.rs b/build.rs deleted file mode 100644 index 3e31cb0a9..000000000 --- a/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("cargo:rerun-if-env-changed=RAYON_NUM_THREADS"); -} diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 36d2e77cb..ac16a6de9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -36,7 +36,6 @@ thread_local = "1.1" [dev-dependencies] base64 = "0.22" cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true serde_json.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9d69e120f..16d5cfe67 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -7,7 +7,6 @@ use ceno_zkvm::{ scheme::prover::ZKVMProver, structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; -use const_env::from_env; use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; @@ -37,31 +36,9 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_add(c: &mut Criterion) { type Pcs = BasefoldDefault; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; let mut zkvm_cs = ZKVMConstraintSystem::default(); let _ = zkvm_cs.register_opcode_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -128,7 +105,6 @@ fn bench_add(c: &mut Criterion) { commit, &[], num_instances, - max_threads, &mut transcript, &challenges, ) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index d322d360c..b979975d2 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -8,7 +8,6 @@ use ceno_zkvm::{ tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers}, }; use clap::Parser; -use const_env::from_env; use ceno_emul::{ ByteAddr, CENO_PLATFORM, EmuContext, @@ -28,9 +27,6 @@ use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; use transcript::Transcript; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - const PROGRAM_SIZE: usize = 512; // For now, we assume registers // - x0 is not touched, @@ -80,27 +76,6 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; - let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -237,7 +212,7 @@ fn main() { let transcript = Transcript::new(b"riscv"); let mut zkvm_proof = prover - .create_proof(zkvm_witness, pi, max_threads, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); println!( diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c92197a5..c87283c01 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -32,7 +32,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads}, + utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, }; @@ -52,7 +52,6 @@ impl> ZKVMProver { &self, witnesses: ZKVMWitnesses, pi: PublicValues, - max_threads: usize, mut transcript: Transcript, ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::empty(pi); @@ -135,7 +134,6 @@ impl> ZKVMProver { wits_commit, pi, num_instances, - max_threads, transcript, &challenges, )?; @@ -155,7 +153,6 @@ impl> ZKVMProver { witness.into_iter().map(|v| v.into()).collect_vec(), wits_commit, pi, - max_threads, transcript, &challenges, )?; @@ -186,7 +183,6 @@ impl> ZKVMProver { wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], num_instances: usize, - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -320,7 +316,6 @@ impl> ZKVMProver { let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; assert!(record_r_out_evals.len() == NUM_FANIN && record_w_out_evals.len() == NUM_FANIN); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, vec![ TowerProverSpec { witness: r_wit_layers, @@ -363,7 +358,7 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); - let num_threads = proper_num_threads(log2_num_instances, max_threads); + let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -624,7 +619,6 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -843,7 +837,6 @@ impl> ZKVMProver { .collect_vec(); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, // pattern [r1, w1, r2, w2, ...] same pair are chain together r_wit_layers .into_iter() @@ -884,7 +877,7 @@ impl> ZKVMProver { // If all table length are the same, we can skip this sumcheck let span = entered_span!("sumcheck::opening_same_point"); // NOTE: max concurrency will be dominated by smallest table since it will blo - let num_threads = proper_num_threads(min_log2_num_instance, max_threads); + let num_threads = optimal_sumcheck_threads(min_log2_num_instance); let alpha_pow = get_challenge_pows( cs.r_table_expressions.len() + cs.w_table_expressions.len() @@ -1074,7 +1067,6 @@ impl TowerProofs { /// Tower Prover impl TowerProver { pub fn create_proof<'a, E: ExtensionField>( - max_threads: usize, prod_specs: Vec>, logup_specs: Vec>, num_fanin: usize, @@ -1109,7 +1101,7 @@ impl TowerProver { let (next_rt, _) = (1..=max_round_index).fold((initial_rt, alpha_pows), |(out_rt, alpha_pows), round| { // in first few round we just run on single thread - let num_threads = proper_num_threads(out_rt.len(), max_threads); + let num_threads = optimal_sumcheck_threads(out_rt.len()); let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 80c43af0b..f928a8918 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -141,7 +141,6 @@ fn test_rw_lk_expression_combination() { commit, &[], num_instances, - 1, &mut transcript, &prover_challenges, ) @@ -290,7 +289,7 @@ fn test_single_add_instance_e2e() { let pi = PublicValues::new(0, 0, 0, 0, 0); let transcript = Transcript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, pi, 1, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 119b13439..47af02967 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -2,6 +2,7 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; use transcript::Transcript; /// convert ext field element to u64, assume it is inside the range @@ -113,7 +114,8 @@ pub fn u64vec(x: u64) -> [u64; W] { /// we expect each thread at least take 4 num of sumcheck variables /// return optimal num threads to run sumcheck -pub fn proper_num_threads(num_vars: usize, expected_max_threads: usize) -> usize { +pub fn optimal_sumcheck_threads(num_vars: usize) -> usize { + let expected_max_threads = max_usable_threads(); let min_numvar_per_thread = 4; if num_vars <= min_numvar_per_thread { 1 diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 702791bd0..fab3ec6b8 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -9,7 +9,6 @@ ark-std.workspace = true ff.workspace = true goldilocks.workspace = true -const_env.workspace = true crossbeam-channel.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index fce8ffb31..d48920dd7 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -3,11 +3,11 @@ use std::time::Duration; -use const_env::from_env; use criterion::*; use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256}; use goldilocks::GoldilocksExt2; +use multilinear_extensions::util::max_usable_threads; cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { @@ -28,8 +28,6 @@ cfg_if::cfg_if! { criterion_main!(keccak256); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_keccak256(c: &mut Criterion) { println!( @@ -37,26 +35,7 @@ fn bench_keccak256(c: &mut Criterion) { keccak256_circuit::().layers.len() ); - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let circuit = keccak256_circuit::(); diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index 28e4f8284..a0a8e56a2 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -30,3 +30,21 @@ pub fn create_uninit_vec(len: usize) -> Vec> { pub fn largest_even_below(n: usize) -> usize { if n % 2 == 0 { n } else { n.saturating_sub(1) } } + +fn prev_power_of_two(n: usize) -> usize { + (n + 1).next_power_of_two() / 2 +} + +/// Largest power of two that fits the available rayon threads +pub fn max_usable_threads() -> usize { + if cfg!(test) { + 1 + } else { + let n = rayon::current_num_threads(); + let threads = prev_power_of_two(n); + if n != threads { + tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead."); + } + threads + } +} diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index a5468ad90..45e4e9fb7 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -2,7 +2,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - util::{bit_decompose, create_uninit_vec}, + util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; use ff::{Field, PrimeField}; @@ -452,8 +452,7 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { // .... // 1 1 1 1 -> r0 * r1 * r2 * r3 // we will need 2^num_var evaluations - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let nbits = nthreads.trailing_zeros() as usize; assert_eq!(1 << nbits, nthreads); diff --git a/singer/Cargo.toml b/singer/Cargo.toml index e6c10602e..ac71ab4b9 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -28,7 +28,6 @@ tracing-subscriber.workspace = true [dev-dependencies] cfg-if.workspace = true -const_env.workspace = true criterion.workspace = true pprof.workspace = true tracing.workspace = true diff --git a/singer/benches/add.rs b/singer/benches/add.rs index 70fee6f28..5984a19b2 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -4,7 +4,6 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::{ExtensionField, ff::Field}; @@ -30,9 +29,8 @@ cfg_if::cfg_if! { criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; +use multilinear_extensions::util::max_usable_threads; use singer::{ CircuitWiresIn, SingerGraphBuilder, SingerParams, instructions::{Instruction, InstructionGraph, SingerCircuitBuilder, add::AddInstruction}, @@ -42,26 +40,7 @@ use singer_utils::structs::ChipChallenges; use transcript::Transcript; fn bench_add(c: &mut Criterion) { - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let chip_challenges = ChipChallenges::default(); let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 4b9605cdf..540495c0a 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -6,7 +6,6 @@ version.workspace = true [dependencies] ark-std.workspace = true -const_env.workspace = true ff.workspace = true ff_ext = { path = "../ff_ext" } goldilocks.workspace = true diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 7116789b9..fd33b9d09 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -4,7 +4,6 @@ use std::array; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::ExtensionField; use itertools::Itertools; @@ -14,6 +13,7 @@ use goldilocks::GoldilocksExt2; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, + util::max_usable_threads, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; use transcript::Transcript; @@ -41,10 +41,10 @@ pub fn transpose(v: Vec>) -> Vec> { } fn prepare_input<'a, E: ExtensionField>( - max_thread_id: usize, nv: usize, ) -> (E, VirtualPolynomial<'a, E>, Vec>) { let mut rng = test_rng(); + let max_thread_id = max_usable_threads(); let size_log2 = ceil_log2(max_thread_id); let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| { let mle: ArcMultilinearExtension<'a, E> = @@ -100,9 +100,6 @@ fn prepare_input<'a, E: ExtensionField>( (asserted_sum, virtual_poly_v1, virtual_poly_v2) } -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; @@ -119,7 +116,7 @@ fn sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -150,6 +147,7 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; + let threads = max_usable_threads(); for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); @@ -163,7 +161,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -178,7 +176,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { virtual_poly_splitted, )| { let (_sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, + threads, virtual_poly_splitted, &mut prover_transcript, ); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs deleted file mode 100644 index 29ff81368..000000000 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::Arc; - -use ark_std::test_rng; -use const_env::from_env; -use ff_ext::{ExtensionField, ff::Field}; -use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, -}; -use sumcheck::{ - structs::{IOPProverState, IOPVerifierState}, - util::ceil_log2, -}; -use transcript::Transcript; - -type E = GoldilocksExt2; - -fn prepare_input( - max_thread_id: usize, -) -> (E, VirtualPolynomial, Vec>) { - let nv = 10; - let mut rng = test_rng(); - let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); - - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - (asserted_sum, virtual_poly_1, virtual_poly_f1) -} - -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - -fn main() { - let mut prover_transcript_v1 = Transcript::::new(b"test"); - let mut prover_transcript_v2 = Transcript::::new(b"test"); - - let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); - let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, - virtual_poly_splitted.clone(), - &mut prover_transcript_v2, - ); - println!("v2 finish"); - - let mut transcript = Transcript::new(b"test"); - let poly_info = virtual_poly.aux_info.clone(); - let subclaim = IOPVerifierState::::verify( - asserted_sum, - &sumcheck_proof_v2, - &poly_info, - &mut transcript, - ); - assert!( - virtual_poly.evaluate( - subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, - "wrong subclaim" - ); - - #[allow(deprecated)] - let (sumcheck_proof_v1, _) = - IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); - - println!("v1 finish"); - assert!(sumcheck_proof_v2 == sumcheck_proof_v1); -} diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 2336e52a9..f2ede8680 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -43,6 +43,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> (IOPProof, IOPProverStateV2<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); + assert!(max_thread_id.is_power_of_two()); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 assert!( From d0ab3afda12c65b0daefa158310941a50dd0edcc Mon Sep 17 00:00:00 2001 From: soham Date: Mon, 28 Oct 2024 14:39:10 +0530 Subject: [PATCH 08/10] Improvements to errors from mock prover (#480) - Prevent duplicate display of WitIn values - For complex expressions used in require_equal it causes the inference of left and right to fail. Doing `to_monomial_form` before the subtraction in `require_equal` fixes this issue. --- ceno_zkvm/src/chip_handler/general.rs | 8 +++++++- ceno_zkvm/src/expression.rs | 4 +++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 255ae20d3..85aafab0f 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -168,7 +168,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.namespace(|| "require_equal", |cb| cb.cs.require_zero(name_fn, a - b)) + self.namespace( + || "require_equal", + |cb| { + cb.cs + .require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form()) + }, + ) } pub fn require_one(&mut self, name_fn: N, expr: Expression) -> Result<(), ZKVMError> diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 0cb39bfe3..4d8aec48c 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -741,7 +741,9 @@ pub mod fmt { ) -> String { match expression { Expression::WitIn(wit_in) => { - wtns.push(*wit_in); + if !wtns.contains(wit_in) { + wtns.push(*wit_in); + } format!("WitIn({})", wit_in) } Expression::Challenge(id, pow, scaler, offset) => { From ef5caa958bdbea01a495c89def991e7c91f3d338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Mon, 28 Oct 2024 17:46:50 +0800 Subject: [PATCH 09/10] Remove unnecessary `goldilocks` patch (#481) We specify Goldilocks as a git dependency, but then immediately patch it. Let's stop doing that. --- Cargo.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ea6d35b27..aea908372 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" ff = "0.13" -goldilocks = { git = "https://github.com/zhenfeizhang/Goldilocks" } +goldilocks = { git = "https://github.com/hero78119/Goldilocks" } itertools = "0.13" paste = "1" plonky2 = "0.2" @@ -51,8 +51,5 @@ tracing = { version = "0.1", features = [ tracing-flame = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -[patch."https://github.com/zhenfeizhang/Goldilocks"] -goldilocks = { git = "https://github.com/hero78119/Goldilocks" } - [profile.release] lto = "thin" From 47f5572ff093b3467b53f4d535b94a162e922b49 Mon Sep 17 00:00:00 2001 From: soham Date: Mon, 28 Oct 2024 15:34:10 +0530 Subject: [PATCH 10/10] SRAI (#463) This PR adds SRAI to shift_imm, also changes logic for SLLI and SRLI. Closes #365 --- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 223 +++++++++++------- ceno_zkvm/src/uint.rs | 1 + 2 files changed, 133 insertions(+), 91 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 1135b5cc0..166d91e21 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -3,11 +3,13 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - gadgets::DivConfig, + expression::{Expression, ToExpr, WitIn}, + gadgets::{AssertLTConfig, IsLtConfig}, instructions::{ Instruction, riscv::{constants::UInt, i_insn::IInstructionConfig}, }, + set_val, witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; @@ -17,13 +19,14 @@ use std::{marker::PhantomData, mem::MaybeUninit}; pub struct ShiftImmConfig { i_insn: IInstructionConfig, + imm: WitIn, rs1_read: UInt, - imm: UInt, rd_written: UInt, + outflow: WitIn, + assert_lt_config: AssertLTConfig, - // for SRLI division arithmetics - remainder: Option>, - div_config: Option>, + // SRAI + is_lt_config: Option, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); @@ -33,9 +36,14 @@ impl RIVInstruction for SlliOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; } +pub struct SraiOp; +impl RIVInstruction for SraiOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI; +} + pub struct SrliOp; impl RIVInstruction for SrliOp { - const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRLI; + const INST_KIND: ceno_emul::InsnKind = InsnKind::SRLI; } impl Instruction for ShiftImmInstruction { @@ -48,47 +56,64 @@ impl Instruction for ShiftImmInstructio fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result { - let mut imm = UInt::new(|| "imm", circuit_builder)?; + // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. + let imm = circuit_builder.create_witin(|| "imm")?; + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow")?; + let assert_lt_config = AssertLTConfig::construct_circuit( + circuit_builder, + || "outflow < imm", + outflow.expr(), + imm.expr(), + 2, + )?; + + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); - // Note: `imm` is set to 2**imm (upto 32 bit) just for efficient verification - // Goal is to constrain: - // rs1 == rd_written * imm + remainder - let (rs1_read, rd_written, remainder, div_config) = match I::INST_KIND { + let is_lt_config = match I::INST_KIND { InsnKind::SLLI => { - let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rd_written = rs1_read.mul( - || "rd_written = rs1_read * imm", - circuit_builder, - &mut imm, - true, + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * imm.expr(), // inflow is zero for this case + outflow.expr() * two_pow_total_bits + rd_written.value(), )?; - - (rs1_read, rd_written, None, None) + None } - InsnKind::SRLI => { - let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?; - let remainder = UInt::new(|| "remainder", circuit_builder)?; - let div_config = DivConfig::construct_circuit( - circuit_builder, - || "srli_div", - &mut imm, - &mut rd_written, - &remainder, + InsnKind::SRAI | InsnKind::SRLI => { + let (inflow, is_lt_config) = match I::INST_KIND { + InsnKind::SRAI => { + let max_signed_limb_expr: Expression<_> = + ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + circuit_builder, + || "lhs_msb", + max_signed_limb_expr.clone(), + rs1_read.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + let msb_expr: Expression = is_rs1_neg.is_lt.expr(); + let ones = imm.expr() - Expression::ONE; + (msb_expr * ones, Some(is_rs1_neg)) + } + InsnKind::SRLI => (Expression::ZERO, None), + _ => unreachable!(), + }; + circuit_builder.require_equal( + || "shift check", + rd_written.value() * imm.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), )?; - ( - div_config.dividend.clone(), - rd_written, - Some(remainder), - Some(div_config), - ) + is_lt_config } - _ => unreachable!(), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.value(), + &imm.expr(), rs1_read.register_expr(), rd_written.register_expr(), false, @@ -97,10 +122,11 @@ impl Instruction for ShiftImmInstructio Ok(ShiftImmConfig { i_insn, imm, - rd_written, rs1_read, - remainder, - div_config, + rd_written, + outflow, + assert_lt_config, + is_lt_config, }) } @@ -110,38 +136,36 @@ impl Instruction for ShiftImmInstructio lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let imm = Value::new(step.insn().imm_or_funct7(), lk_multiplicity); - match I::INST_KIND { - InsnKind::SLLI => { - let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); - let rd_written = rs1_read.mul(&imm, lk_multiplicity, true); - config.rs1_read.assign_value(instance, rs1_read); - config - .rd_written - .assign_mul_outcome(instance, lk_multiplicity, &rd_written)?; - } - InsnKind::SRLI => { - let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - let rs1_read = step.rs1().unwrap().value; - let remainder = Value::new(rs1_read % imm.as_u32(), lk_multiplicity); - config.div_config.as_ref().unwrap().assign_instance( - instance, - lk_multiplicity, - &imm, - &rd_written, - &remainder, - )?; - config - .remainder - .as_ref() - .unwrap() - .assign_value(instance, remainder); - config.rd_written.assign_value(instance, rd_written); + let imm = step.insn().imm_or_funct7(); + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); + let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); + + set_val!(instance, config.imm, imm as u64); + config.rs1_read.assign_value(instance, rs1_read.clone()); + config.rd_written.assign_value(instance, rd_written); + + let outflow = match I::INST_KIND { + InsnKind::SLLI => (rs1_read.as_u64() * imm as u64) >> UInt::::TOTAL_BITS, + InsnKind::SRAI | InsnKind::SRLI => { + if I::INST_KIND == InsnKind::SRAI { + let max_signed_limb_expr = (1 << (UInt::::LIMB_BITS - 1)) - 1; + config.is_lt_config.as_ref().unwrap().assign_instance( + instance, + lk_multiplicity, + max_signed_limb_expr, + rs1_read.as_u64() >> UInt::::LIMB_BITS, + )?; + } + + rs1_read.as_u64() & (imm as u64 - 1) } - _ => unreachable!(), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; - config.imm.assign_value(instance, imm); + set_val!(instance, config.outflow, outflow); + config + .assert_lt_config + .assign_instance(instance, lk_multiplicity, outflow, imm as u64)?; config .i_insn @@ -158,6 +182,7 @@ mod test { use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; + use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -168,30 +193,51 @@ mod test { scheme::mock_prover::{MOCK_PC_START, MockProver}, }; - use super::{ShiftImmInstruction, SlliOp, SrliOp}; - #[test] fn test_opcode_slli() { - verify::("imm = 3, rs1 = 32", 3, 32, 32 << 3); - verify::("imm = 3, rs1 = 33", 3, 33, 33 << 3); + // imm = 3 + verify::("32 << 3", 32, 3, 32 << 3); + verify::("33 << 3", 33, 3, 33 << 3); + // imm = 31 + verify::("32 << 31", 32, 31, 32 << 31); + verify::("33 << 31", 33, 31, 33 << 31); + } - verify::("imm = 31, rs1 = 32", 31, 32, 32 << 31); - verify::("imm = 31, rs1 = 33", 31, 33, 33 << 31); + #[test] + fn test_opcode_srai() { + // positive rs1 + // imm = 3 + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); + // imm = 31 + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); + + // negative rs1 + // imm = 3 + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); + verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); + // imm = 31 + verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); + verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); } #[test] fn test_opcode_srli() { - verify::("imm = 3, rs1 = 32", 3, 32, 32 >> 3); - verify::("imm = 3, rs1 = 33", 3, 33, 33 >> 3); - - verify::("imm = 31, rs1 = 32", 31, 32, 32 >> 31); - verify::("imm = 31, rs1 = 33", 31, 33, 33 >> 31); + // imm = 3 + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); + // imm = 31 + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); + // rs1 top bit is 1 + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32) as u32 >> 3); } fn verify( name: &'static str, - imm: u32, rs1_read: u32, + imm: u32, expected_rd_written: u32, ) { let mut cs = ConstraintSystem::::new(|| "riscv"); @@ -203,6 +249,11 @@ mod test { encode_rv32(InsnKind::SLLI, 2, 0, 4, imm), rs1_read << imm, ), + InsnKind::SRAI => ( + "SRAI", + encode_rv32(InsnKind::SRAI, 2, 0, 4, imm), + (rs1_read as i32 >> imm as i32) as u32, + ), InsnKind::SRLI => ( "SRLI", encode_rv32(InsnKind::SRLI, 2, 0, 4, imm), @@ -225,7 +276,7 @@ mod test { config .rd_written .require_equal( - || "assert_rd_written", + || format!("{prefix}_({name})_assert_rd_written"), &mut cb, &UInt::from_const_unchecked( Value::new_unchecked(expected_rd_written) @@ -249,16 +300,6 @@ mod test { ) .unwrap(); - let expected_rd_written = UInt::from_const_unchecked( - Value::new_unchecked(expected_rd_written) - .as_u16_limbs() - .to_vec(), - ); - config - .rd_written - .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) - .unwrap(); - MockProver::assert_satisfied( &cb, &raw_witin diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 1470aa14d..f243e3769 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -636,6 +636,7 @@ impl ValueMul { } } +#[derive(Clone)] pub struct Value<'a, T: Into + From + Copy + Default> { #[allow(dead_code)] val: T,