Skip to content

Commit

Permalink
BEQ BNE Circuit (#257)
Browse files Browse the repository at this point in the history
_Issue #136 #137_

Depends on #272 (done).

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Sep 24, 2024
1 parent 5004d94 commit 3f54ed4
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 25 deletions.
28 changes: 28 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,34 @@ impl StepRecord {
}
}

pub fn new_b_instruction(
cycle: Cycle,
pc: Change<ByteAddr>,
insn_code: Word,
rs1_read: Word,
rs2_read: Word,
previous_cycle: Cycle,
) -> StepRecord {
let insn = DecodedInstruction::new(insn_code);
StepRecord {
cycle,
pc,
insn_code,
rs1: Some(ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(),
value: rs1_read,
previous_cycle,
}),
rs2: Some(ReadOp {
addr: CENO_PLATFORM.register_vma(insn.rs2() as RegIdx).into(),
value: rs2_read,
previous_cycle,
}),
rd: None,
memory_op: None,
}
}

pub fn cycle(&self) -> Cycle {
self.cycle
}
Expand Down
16 changes: 8 additions & 8 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ impl<E: ExtensionField> ConstraintSystem<E> {
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
"rlc lk_record degree ({})",
name_fn().into()
);
self.lk_expressions.push(rlc_record);
let path = self.ns.compute_path(name_fn().into());
Expand All @@ -223,8 +223,8 @@ impl<E: ExtensionField> ConstraintSystem<E> {
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
"rlc lk_table_record degree ({})",
name_fn().into()
);
self.lk_table_expressions.push(LogupTableExpression {
values: rlc_record,
Expand All @@ -244,8 +244,8 @@ impl<E: ExtensionField> ConstraintSystem<E> {
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
"rlc read_record degree ({})",
name_fn().into()
);
self.r_expressions.push(rlc_record);
let path = self.ns.compute_path(name_fn().into());
Expand All @@ -261,8 +261,8 @@ impl<E: ExtensionField> ConstraintSystem<E> {
assert_eq!(
rlc_record.degree(),
1,
"rlc record degree {} != 1",
rlc_record.degree()
"rlc write_record degree ({})",
name_fn().into()
);
self.w_expressions.push(rlc_record);
let path = self.ns.compute_path(name_fn().into());
Expand Down
10 changes: 7 additions & 3 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ impl<E: ExtensionField> Expression<E> {
Expression::Constant(c) => *c == E::BaseField::ZERO,
Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b),
Expression::Product(a, b) => Self::is_zero_expr(a) || Self::is_zero_expr(b),
Expression::ScaledSum(_, _, _) => false,
Expression::Challenge(_, _, _, _) => false,
Expression::ScaledSum(x, a, b) => {
(Self::is_zero_expr(x) || Self::is_zero_expr(a)) && Self::is_zero_expr(b)
}
Expression::Challenge(_, _, scalar, offset) => *scalar == E::ZERO && *offset == E::ZERO,
}
}

Expand All @@ -143,7 +145,9 @@ impl<E: ExtensionField> Expression<E> {
&& Self::is_monomial_form_inner(MonomialState::ProductTerm, b)
}
(Expression::ScaledSum(_, _, _), MonomialState::SumTerm) => true,
(Expression::ScaledSum(_, _, b), MonomialState::ProductTerm) => Self::is_zero_expr(b),
(Expression::ScaledSum(x, a, b), MonomialState::ProductTerm) => {
Self::is_zero_expr(x) || Self::is_zero_expr(a) || Self::is_zero_expr(b)
}
}
}
}
Expand Down
12 changes: 2 additions & 10 deletions ceno_zkvm/src/expression/monomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,10 @@ impl<E: ExtensionField> Expression<E> {
fn sum_terms(terms: Vec<Term<E>>) -> Self {
terms
.into_iter()
.map(|term| term.vars.into_iter().fold(term.coeff, Self::product))
.reduce(Self::sum)
.map(|term| term.vars.into_iter().fold(term.coeff, |a, b| a * b))
.reduce(|a, b| a + b)
.unwrap_or(Expression::ZERO)
}

fn product(a: Self, b: Self) -> Self {
Product(Box::new(a), Box::new(b))
}

fn sum(a: Self, b: Self) -> Self {
Sum(Box::new(a), Box::new(b))
}
}

#[derive(Clone, Debug)]
Expand Down
80 changes: 80 additions & 0 deletions ceno_zkvm/src/gadgets/is_zero.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::mem::MaybeUninit;

use ff_ext::ExtensionField;
use goldilocks::SmallField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr, WitIn},
set_val,
};

pub struct IsZeroConfig {
is_zero: WitIn,
inverse: WitIn,
}

impl IsZeroConfig {
pub fn expr<E: ExtensionField>(&self) -> Expression<E> {
self.is_zero.expr()
}

pub fn construct_circuit<E: ExtensionField>(
cb: &mut CircuitBuilder<E>,
x: Expression<E>,
) -> Result<Self, ZKVMError> {
let is_zero = cb.create_witin(|| "is_zero")?;
let inverse = cb.create_witin(|| "inv")?;

// x==0 => is_zero=1
cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?;

// x!=0 => is_zero=0
cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?;

Ok(IsZeroConfig { is_zero, inverse })
}

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
x: F,
) -> Result<(), ZKVMError> {
let (is_zero, inverse) = if x.is_zero_vartime() {
(F::ONE, F::ZERO)
} else {
(F::ZERO, x.invert().expect("not zero"))
};

set_val!(instance, self.is_zero, is_zero);
set_val!(instance, self.inverse, inverse);

Ok(())
}
}

pub struct IsEqualConfig(IsZeroConfig);

impl IsEqualConfig {
pub fn expr<E: ExtensionField>(&self) -> Expression<E> {
self.0.expr()
}

pub fn construct_circuit<E: ExtensionField>(
cb: &mut CircuitBuilder<E>,
a: Expression<E>,
b: Expression<E>,
) -> Result<Self, ZKVMError> {
Ok(IsEqualConfig(IsZeroConfig::construct_circuit(cb, a - b)?))
}

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
a: F,
b: F,
) -> Result<(), ZKVMError> {
self.0.assign_instance(instance, a - b)
}
}
2 changes: 2 additions & 0 deletions ceno_zkvm/src/gadgets/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod is_zero;
pub use is_zero::{IsEqualConfig, IsZeroConfig};
1 change: 1 addition & 0 deletions ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ceno_emul::InsnKind;

pub mod arith;
pub mod blt;
pub mod branch;
pub mod config;
pub mod constants;
pub mod logic;
Expand Down
17 changes: 13 additions & 4 deletions ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use core::mem::MaybeUninit;
#[derive(Debug)]
pub struct BInstructionConfig {
pc: WitIn,
next_pc: WitIn,
ts: WitIn,
rs1_id: WitIn,
rs2_id: WitIn,
Expand Down Expand Up @@ -98,13 +99,20 @@ impl BInstructionConfig {
)?;

// State out.
let pc_offset = branch_taken_bit * (imm.expr() - PC_STEP_SIZE.into()) + PC_STEP_SIZE.into();
let next_pc = pc.expr() + pc_offset;
let next_pc = {
let pc_offset = branch_taken_bit.clone() * imm.expr()
- branch_taken_bit * PC_STEP_SIZE.into()
+ PC_STEP_SIZE.into();
let next_pc = circuit_builder.create_witin(|| "next_pc")?;
circuit_builder.require_equal(|| "pc_branch", next_pc.expr(), pc.expr() + pc_offset)?;
next_pc
};
let next_ts = cur_ts.expr() + 4.into();
circuit_builder.state_out(next_pc, next_ts)?;
circuit_builder.state_out(next_pc.expr(), next_ts)?;

Ok(BInstructionConfig {
pc,
next_pc,
ts: cur_ts,
rs1_id,
rs2_id,
Expand All @@ -122,8 +130,9 @@ impl BInstructionConfig {
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
// State in.
// State.
set_val!(instance, self.pc, step.pc().before.0 as u64);
set_val!(instance, self.next_pc, step.pc().after.0 as u64);
set_val!(instance, self.ts, step.cycle());

// Register indexes and immediate.
Expand Down
19 changes: 19 additions & 0 deletions ceno_zkvm/src/instructions/riscv/branch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
mod beq_circuit;
use super::RIVInstruction;
use beq_circuit::BeqCircuit;
use ceno_emul::InsnKind;

#[cfg(test)]
mod test;

pub struct BeqOp;
impl RIVInstruction for BeqOp {
const INST_KIND: InsnKind = InsnKind::BEQ;
}
pub type BeqInstruction<E> = BeqCircuit<E, BeqOp>;

pub struct BneOp;
impl RIVInstruction for BneOp {
const INST_KIND: InsnKind = InsnKind::BNE;
}
pub type BneInstruction<E> = BeqCircuit<E, BneOp>;
97 changes: 97 additions & 0 deletions ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::{marker::PhantomData, mem::MaybeUninit};

use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::Expression,
gadgets::IsEqualConfig,
instructions::{
riscv::{b_insn::BInstructionConfig, constants::UInt, RIVInstruction},
Instruction,
},
witness::LkMultiplicity,
Value,
};

pub struct BeqConfig<E: ExtensionField> {
b_insn: BInstructionConfig,

// TODO: Limb decomposition is not necessary. Replace with a single witness.
rs1_read: UInt<E>,
rs2_read: UInt<E>,

equal: IsEqualConfig,
}

pub struct BeqCircuit<E, I>(PhantomData<(E, I)>);

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BeqCircuit<E, I> {
type InstructionConfig = BeqConfig<E>;

fn name() -> String {
format!("{:?}", I::INST_KIND)
}

fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?;

let equal =
IsEqualConfig::construct_circuit(circuit_builder, rs2_read.value(), rs1_read.value())?;

let branch_taken_bit = match I::INST_KIND {
InsnKind::BEQ => equal.expr(),
InsnKind::BNE => Expression::ONE - equal.expr(),
_ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND),
};

let b_insn = BInstructionConfig::construct_circuit(
circuit_builder,
I::INST_KIND,
rs1_read.register_expr(),
rs2_read.register_expr(),
branch_taken_bit,
)?;

Ok(BeqConfig {
b_insn,
rs1_read,
rs2_read,
equal,
})
}

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<<E as ExtensionField>::BaseField>],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
config
.b_insn
.assign_instance::<E>(instance, lk_multiplicity, step)?;

let rs1_read = step.rs1().unwrap().value;
config
.rs1_read
.assign_limbs(instance, Value::new_unchecked(rs1_read).u16_fields());

let rs2_read = step.rs2().unwrap().value;
config
.rs2_read
.assign_limbs(instance, Value::new_unchecked(rs2_read).u16_fields());

config.equal.assign_instance(
instance,
E::BaseField::from(rs2_read as u64),
E::BaseField::from(rs1_read as u64),
)?;

Ok(())
}
}
Loading

0 comments on commit 3f54ed4

Please sign in to comment.