diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index a80e11c4a..39ee50458 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -7,7 +7,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{ToExpr, WitIn}, + expression::{Expression, ToExpr, WitIn}, instructions::{ Instruction, riscv::{constants::UInt, i_insn::IInstructionConfig, insn_base::MemAddr}, @@ -23,7 +23,7 @@ pub struct JalrConfig { pub rs1_read: UInt, pub imm: WitIn, pub next_pc_addr: MemAddr, - pub overflow: WitIn, + pub overflow: Option<(WitIn, WitIn)>, pub rd_written: UInt, } @@ -63,17 +63,25 @@ impl Instruction for JalrInstruction { // 3. next_pc = next_pc_addr aligned to even value (round down) let next_pc_addr = MemAddr::::construct_unaligned(circuit_builder)?; - let overflow = circuit_builder.create_witin(|| "overflow"); + + let (overflow_expr, overflow) = if cfg!(feature = "forbid_overflow") { + (Expression::ZERO, None) + } else { + let overflow = circuit_builder.create_witin(|| "overflow"); + let tmp = circuit_builder.create_witin(|| "overflow1"); + circuit_builder.require_zero(|| "overflow_0_or_pm1", overflow.expr() * tmp.expr())?; + circuit_builder.require_equal( + || "overflow_tmp", + tmp.expr(), + (1 - overflow.expr()) * (1 + overflow.expr()), + )?; + (overflow.expr(), Some((overflow, tmp))) + }; circuit_builder.require_equal( || "rs1+imm = next_pc_unrounded + overflow*2^32", rs1_read.value() + imm.expr(), - next_pc_addr.expr_unaligned() + overflow.expr() * (1u64 << 32), - )?; - - circuit_builder.require_zero( - || "overflow_0_or_pm1", - overflow.expr() * (overflow.expr() - 1) * (overflow.expr() + 1), + next_pc_addr.expr_unaligned() + overflow_expr * (1u64 << 32), )?; circuit_builder.require_equal( @@ -126,12 +134,18 @@ impl Instruction for JalrInstruction { config .next_pc_addr .assign_instance(instance, lk_multiplicity, sum)?; - let overflow: E::BaseField = match (overflowing, imm < 0) { - (false, _) => E::BaseField::ZERO, - (true, false) => E::BaseField::ONE, - (true, true) => -E::BaseField::ONE, - }; - set_val!(instance, config.overflow, overflow); + + if let Some((overflow_cfg, tmp_cfg)) = &config.overflow { + let (overflow, tmp) = match (overflowing, imm < 0) { + (false, _) => (E::BaseField::ZERO, E::BaseField::ONE), + (true, false) => (E::BaseField::ONE, E::BaseField::ZERO), + (true, true) => (-E::BaseField::ONE, E::BaseField::ZERO), + }; + set_val!(instance, overflow_cfg, overflow); + set_val!(instance, tmp_cfg, tmp); + } else { + assert!(!overflowing, "overflow not allowed in JALR"); + } config .i_insn diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 887db8da3..4453f4c3a 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -62,7 +62,7 @@ fn test_opcode_jalr() { .unwrap(); let imm = -15i32; - let rs1_read: Word = 10u32; + let rs1_read: Word = 100u32; let new_pc: ByteAddr = ByteAddr(rs1_read.wrapping_add_signed(imm) & (!1)); let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm as u32); diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 541236d28..257019513 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -30,6 +30,7 @@ use std::{ }; use strum::IntoEnumIterator; +const MAX_CONSTRAINT_DEGREE: usize = 2; const MOCK_PROGRAM_SIZE: usize = 32; pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_base()); @@ -50,6 +51,11 @@ pub(crate) enum MockProverError { name: String, inst_id: usize, }, + DegreeTooHigh { + expression: Expression, + degree: usize, + name: String, + }, LookupError { expression: Expression, evaluated: E, @@ -178,6 +184,18 @@ impl MockProverError { Inst[{inst_id}]:\n{wtns_fmt}\n", ); } + Self::DegreeTooHigh { + expression, + degree, + name, + } => { + let expression_fmt = fmt::expr(expression, &mut wtns, false); + println!( + "\nDegreeTooHigh {name:?}: Expression degree is too high\n\ + Expression: {expression_fmt}\n\ + Degree: {degree} > {MAX_CONSTRAINT_DEGREE}\n", + ); + } Self::LookupError { expression, evaluated, @@ -251,6 +269,7 @@ impl MockProverError { | Self::AssertEqualError { inst_id, .. } | Self::LookupError { inst_id, .. } | Self::LkMultiplicityError { inst_id, .. } => *inst_id, + Self::DegreeTooHigh { .. } => unreachable!(), } } @@ -438,6 +457,14 @@ impl<'a, E: ExtensionField + Hash> MockProver { .chain(&cb.cs.assert_zero_sumcheck_expressions_namespace_map), ) { + if expr.degree() > MAX_CONSTRAINT_DEGREE { + errors.push(MockProverError::DegreeTooHigh { + expression: expr.clone(), + degree: expr.degree(), + name: name.clone(), + }); + } + // require_equal does not always have the form of Expr::Sum as // the sum of witness and constant is expressed as scaled sum if name.contains("require_equal") && expr.unpack_sum().is_some() { @@ -701,6 +728,7 @@ Hints: .collect_vec(); Self::assert_satisfied(cb, &wits_in, programs, challenge, lkm); } + pub fn assert_satisfied( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>],