Skip to content

Commit

Permalink
fix/jalr-degree
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurélien Nicolas committed Oct 30, 2024
1 parent cc41908 commit 2a5bc11
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
44 changes: 29 additions & 15 deletions ceno_zkvm/src/instructions/riscv/jump/jalr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
Value,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{ToExpr, WitIn},
expression::{Expression, ToExpr, WitIn},
instructions::{
Instruction,
riscv::{constants::UInt, i_insn::IInstructionConfig, insn_base::MemAddr},
Expand All @@ -23,7 +23,7 @@ pub struct JalrConfig<E: ExtensionField> {
pub rs1_read: UInt<E>,
pub imm: WitIn,
pub next_pc_addr: MemAddr<E>,
pub overflow: WitIn,
pub overflow: Option<(WitIn, WitIn)>,
pub rd_written: UInt<E>,
}

Expand Down Expand Up @@ -63,17 +63,25 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
// 3. next_pc = next_pc_addr aligned to even value (round down)

let next_pc_addr = MemAddr::<E>::construct_unaligned(circuit_builder)?;
let overflow = circuit_builder.create_witin(|| "overflow");

let (overflow_expr, overflow) = if cfg!(feature = "forbid_overflow") {
(Expression::ZERO, None)
} else {
let overflow = circuit_builder.create_witin(|| "overflow");
let tmp = circuit_builder.create_witin(|| "overflow1");
circuit_builder.require_zero(|| "overflow_0_or_pm1", overflow.expr() * tmp.expr())?;
circuit_builder.require_equal(
|| "overflow_tmp",
tmp.expr(),
(1 - overflow.expr()) * (1 + overflow.expr()),
)?;
(overflow.expr(), Some((overflow, tmp)))
};

circuit_builder.require_equal(
|| "rs1+imm = next_pc_unrounded + overflow*2^32",
rs1_read.value() + imm.expr(),
next_pc_addr.expr_unaligned() + overflow.expr() * (1u64 << 32),
)?;

circuit_builder.require_zero(
|| "overflow_0_or_pm1",
overflow.expr() * (overflow.expr() - 1) * (overflow.expr() + 1),
next_pc_addr.expr_unaligned() + overflow_expr * (1u64 << 32),
)?;

circuit_builder.require_equal(
Expand Down Expand Up @@ -126,12 +134,18 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
config
.next_pc_addr
.assign_instance(instance, lk_multiplicity, sum)?;
let overflow: E::BaseField = match (overflowing, imm < 0) {
(false, _) => E::BaseField::ZERO,
(true, false) => E::BaseField::ONE,
(true, true) => -E::BaseField::ONE,
};
set_val!(instance, config.overflow, overflow);

if let Some((overflow_cfg, tmp_cfg)) = &config.overflow {
let (overflow, tmp) = match (overflowing, imm < 0) {
(false, _) => (E::BaseField::ZERO, E::BaseField::ONE),
(true, false) => (E::BaseField::ONE, E::BaseField::ZERO),
(true, true) => (-E::BaseField::ONE, E::BaseField::ZERO),
};
set_val!(instance, overflow_cfg, overflow);
set_val!(instance, tmp_cfg, tmp);
} else {
assert!(!overflowing, "overflow not allowed in JALR");
}

config
.i_insn
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/jump/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn test_opcode_jalr() {
.unwrap();

let imm = -15i32;
let rs1_read: Word = 10u32;
let rs1_read: Word = 100u32;
let new_pc: ByteAddr = ByteAddr(rs1_read.wrapping_add_signed(imm) & (!1));
let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm as u32);

Expand Down
28 changes: 28 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::{
};
use strum::IntoEnumIterator;

const MAX_CONSTRAINT_DEGREE: usize = 2;
const MOCK_PROGRAM_SIZE: usize = 32;
pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_base());

Expand All @@ -50,6 +51,11 @@ pub(crate) enum MockProverError<E: ExtensionField> {
name: String,
inst_id: usize,
},
DegreeTooHigh {
expression: Expression<E>,
degree: usize,
name: String,
},
LookupError {
expression: Expression<E>,
evaluated: E,
Expand Down Expand Up @@ -178,6 +184,18 @@ impl<E: ExtensionField> MockProverError<E> {
Inst[{inst_id}]:\n{wtns_fmt}\n",
);
}
Self::DegreeTooHigh {
expression,
degree,
name,
} => {
let expression_fmt = fmt::expr(expression, &mut wtns, false);
println!(
"\nDegreeTooHigh {name:?}: Expression degree is too high\n\
Expression: {expression_fmt}\n\
Degree: {degree} > {MAX_CONSTRAINT_DEGREE}\n",
);
}
Self::LookupError {
expression,
evaluated,
Expand Down Expand Up @@ -251,6 +269,7 @@ impl<E: ExtensionField> MockProverError<E> {
| Self::AssertEqualError { inst_id, .. }
| Self::LookupError { inst_id, .. }
| Self::LkMultiplicityError { inst_id, .. } => *inst_id,
Self::DegreeTooHigh { .. } => unreachable!(),
}
}

Expand Down Expand Up @@ -438,6 +457,14 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
.chain(&cb.cs.assert_zero_sumcheck_expressions_namespace_map),
)
{
if expr.degree() > MAX_CONSTRAINT_DEGREE {
errors.push(MockProverError::DegreeTooHigh {
expression: expr.clone(),
degree: expr.degree(),
name: name.clone(),
});
}

// require_equal does not always have the form of Expr::Sum as
// the sum of witness and constant is expressed as scaled sum
if name.contains("require_equal") && expr.unpack_sum().is_some() {
Expand Down Expand Up @@ -701,6 +728,7 @@ Hints:
.collect_vec();
Self::assert_satisfied(cb, &wits_in, programs, challenge, lkm);
}

pub fn assert_satisfied(
cb: &CircuitBuilder<E>,
wits_in: &[ArcMultilinearExtension<'a, E>],
Expand Down

0 comments on commit 2a5bc11

Please sign in to comment.