Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix/jalr-degree #504

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions ceno_zkvm/src/instructions/riscv/jump/jalr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
Value,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{ToExpr, WitIn},
expression::{Expression, ToExpr, WitIn},
instructions::{
Instruction,
riscv::{constants::UInt, i_insn::IInstructionConfig, insn_base::MemAddr},
Expand All @@ -23,7 +23,7 @@ pub struct JalrConfig<E: ExtensionField> {
pub rs1_read: UInt<E>,
pub imm: WitIn,
pub next_pc_addr: MemAddr<E>,
pub overflow: WitIn,
pub overflow: Option<(WitIn, WitIn)>,
pub rd_written: UInt<E>,
}

Expand Down Expand Up @@ -63,17 +63,25 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
// 3. next_pc = next_pc_addr aligned to even value (round down)

let next_pc_addr = MemAddr::<E>::construct_unaligned(circuit_builder)?;
let overflow = circuit_builder.create_witin(|| "overflow");

let (overflow_expr, overflow) = if cfg!(feature = "forbid_overflow") {
(Expression::ZERO, None)
} else {
let overflow = circuit_builder.create_witin(|| "overflow");
let tmp = circuit_builder.create_witin(|| "overflow1");
circuit_builder.require_zero(|| "overflow_0_or_pm1", overflow.expr() * tmp.expr())?;
circuit_builder.require_equal(
|| "overflow_tmp",
tmp.expr(),
(1 - overflow.expr()) * (1 + overflow.expr()),
)?;
(overflow.expr(), Some((overflow, tmp)))
};

circuit_builder.require_equal(
|| "rs1+imm = next_pc_unrounded + overflow*2^32",
rs1_read.value() + imm.expr(),
next_pc_addr.expr_unaligned() + overflow.expr() * (1u64 << 32),
)?;

circuit_builder.require_zero(
|| "overflow_0_or_pm1",
overflow.expr() * (overflow.expr() - 1) * (overflow.expr() + 1),
next_pc_addr.expr_unaligned() + overflow_expr * (1u64 << 32),
)?;

circuit_builder.require_equal(
Expand Down Expand Up @@ -126,12 +134,18 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
config
.next_pc_addr
.assign_instance(instance, lk_multiplicity, sum)?;
let overflow: E::BaseField = match (overflowing, imm < 0) {
(false, _) => E::BaseField::ZERO,
(true, false) => E::BaseField::ONE,
(true, true) => -E::BaseField::ONE,
};
set_val!(instance, config.overflow, overflow);

if let Some((overflow_cfg, tmp_cfg)) = &config.overflow {
let (overflow, tmp) = match (overflowing, imm < 0) {
(false, _) => (E::BaseField::ZERO, E::BaseField::ONE),
(true, false) => (E::BaseField::ONE, E::BaseField::ZERO),
(true, true) => (-E::BaseField::ONE, E::BaseField::ZERO),
};
set_val!(instance, overflow_cfg, overflow);
set_val!(instance, tmp_cfg, tmp);
} else {
assert!(!overflowing, "overflow not allowed in JALR");
}

config
.i_insn
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/jump/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn test_opcode_jalr() {
.unwrap();

let imm = -15i32;
let rs1_read: Word = 10u32;
let rs1_read: Word = 100u32;
let new_pc: ByteAddr = ByteAddr(rs1_read.wrapping_add_signed(imm) & (!1));
let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm as u32);

Expand Down
28 changes: 28 additions & 0 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::{
};
use strum::IntoEnumIterator;

const MAX_CONSTRAINT_DEGREE: usize = 2;
const MOCK_PROGRAM_SIZE: usize = 32;
pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_base());

Expand All @@ -50,6 +51,11 @@ pub(crate) enum MockProverError<E: ExtensionField> {
name: String,
inst_id: usize,
},
DegreeTooHigh {
expression: Expression<E>,
degree: usize,
name: String,
},
LookupError {
expression: Expression<E>,
evaluated: E,
Expand Down Expand Up @@ -178,6 +184,18 @@ impl<E: ExtensionField> MockProverError<E> {
Inst[{inst_id}]:\n{wtns_fmt}\n",
);
}
Self::DegreeTooHigh {
expression,
degree,
name,
} => {
let expression_fmt = fmt::expr(expression, &mut wtns, false);
println!(
"\nDegreeTooHigh {name:?}: Expression degree is too high\n\
Expression: {expression_fmt}\n\
Degree: {degree} > {MAX_CONSTRAINT_DEGREE}\n",
);
}
Self::LookupError {
expression,
evaluated,
Expand Down Expand Up @@ -251,6 +269,7 @@ impl<E: ExtensionField> MockProverError<E> {
| Self::AssertEqualError { inst_id, .. }
| Self::LookupError { inst_id, .. }
| Self::LkMultiplicityError { inst_id, .. } => *inst_id,
Self::DegreeTooHigh { .. } => unreachable!(),
}
}

Expand Down Expand Up @@ -438,6 +457,14 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
.chain(&cb.cs.assert_zero_sumcheck_expressions_namespace_map),
)
{
if expr.degree() > MAX_CONSTRAINT_DEGREE {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

errors.push(MockProverError::DegreeTooHigh {
expression: expr.clone(),
degree: expr.degree(),
name: name.clone(),
});
}

// require_equal does not always have the form of Expr::Sum as
// the sum of witness and constant is expressed as scaled sum
if name.contains("require_equal") && expr.unpack_sum().is_some() {
Expand Down Expand Up @@ -701,6 +728,7 @@ Hints:
.collect_vec();
Self::assert_satisfied(cb, &wits_in, programs, challenge, lkm);
}

pub fn assert_satisfied(
cb: &CircuitBuilder<E>,
wits_in: &[ArcMultilinearExtension<'a, E>],
Expand Down
Loading