Skip to content

Commit

Permalink
refine shift with updated div_config and adding some test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
KimiWu123 committed Oct 8, 2024
1 parent 03813b5 commit e9eaf3d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 53 deletions.
5 changes: 0 additions & 5 deletions ceno_zkvm/src/instructions/riscv/divu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E
fn construct_circuit(
circuit_builder: &mut CircuitBuilder<E>,
) -> Result<Self::InstructionConfig, ZKVMError> {
// outcome = dividend / divisor + remainder => dividend = divisor * outcome + r
// let mut divisor = UInt::new_unchecked(|| "divisor", circuit_builder)?;
// let mut outcome = UInt::new(|| "outcome", circuit_builder)?;
// let r = UInt::new(|| "remainder", circuit_builder)?;

let div_config = DivConfig::construct_circuit(circuit_builder, || "divu")?;

// div by zero check
Expand Down
74 changes: 26 additions & 48 deletions ceno_zkvm/src/instructions/riscv/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ pub struct ShiftConfig<E: ExtensionField> {
pow2_rs2_low5: UInt<E>,

// for SRL division arithmetics
remainder: Option<UInt<E>>,
div_config: Option<DivConfig<E>>,
}

Expand Down Expand Up @@ -57,7 +56,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstru
// rs2 = rs2_high | rs2_low5
let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?;

let (rs1_read, rd_written, remainder, div_config) = match I::INST_KIND {
let (rs1_read, rd_written, div_config) = match I::INST_KIND {
InsnKind::SLL => {
let mut rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?;
let rd_written = rs1_read.mul(
Expand All @@ -66,22 +65,13 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstru
&mut pow2_rs2_low5,
true,
)?;
(rs1_read, rd_written, None, None)
(rs1_read, rd_written, None)
}
InsnKind::SRL => {
let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?;
let remainder = UInt::new(|| "remainder", circuit_builder)?;
let div_config = DivConfig::construct_circuit(
circuit_builder,
|| "srl_div",
&mut pow2_rs2_low5,
&mut rd_written,
&remainder,
)?;
let div_config = DivConfig::construct_circuit(circuit_builder, || "srl_div")?;
(
div_config.dividend.clone(),
rd_written,
Some(remainder),
div_config.quotient.clone(),
Some(div_config),
)
}
Expand Down Expand Up @@ -112,7 +102,6 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstru
rs2_high,
rs2_low5,
pow2_rs2_low5,
remainder,
div_config,
})
}
Expand Down Expand Up @@ -160,11 +149,6 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ShiftLogicalInstru
)?;

config.rd_written.assign_value(instance, rd_written);
config
.remainder
.as_ref()
.unwrap()
.assign_value(instance, remainder);
}
_ => unreachable!(),
}
Expand Down Expand Up @@ -201,51 +185,45 @@ mod tests {
use super::{ShiftLogicalInstruction, SllOp, SrlOp};

#[test]
fn test_opcode_sll_1() {
verify::<SllOp>(0b_1, 3, 0b_1000);
}

#[test]
fn test_opcode_sll_2_rs2_overflow() {
fn test_opcode_sll() {
verify::<SllOp>("basic", 0b_0001, 3, 0b_1000);
// 33 << 33 === 33 << 1
verify::<SllOp>(0b_1, 33, 0b_10);
}

#[test]
fn test_opcode_sll_3_bit_loss() {
verify::<SllOp>(1 << 31 | 1, 1, 0b_10);
}

#[test]
fn test_opcode_srl_1() {
verify::<SrlOp>(0b_1000, 3, 0b_1);
}

#[test]
fn test_opcode_srl_2_rs2_overflow() {
// 33 >> 33 === 33 >> 1
verify::<SrlOp>(0b_1010, 33, 0b_101);
verify::<SllOp>("rs2 over 5-bits", 0b_0001, 33, 0b_0010);
verify::<SllOp>("bit loss", 1 << 31 | 1, 1, 0b_0010);
verify::<SllOp>("zero shift", 0b_0001, 0, 0b_0001);
verify::<SllOp>("all zeros", 0b_0000, 0, 0b_0000);
verify::<SllOp>("base is zero", 0b_0000, 1, 0b_0000);
}

#[test]
fn test_opcode_srl_3_bit_loss() {
fn test_opcode_srl() {
verify::<SrlOp>("basic", 0b_1000, 3, 0b_0001);
// 33 >> 33 === 33 >> 1
verify::<SrlOp>(0b_1001, 1, 0b_100);
verify::<SrlOp>("rs2 over 5-bits", 0b_1010, 33, 0b_0101);
verify::<SrlOp>("bit loss", 0b_1001, 1, 0b_0100);
verify::<SrlOp>("zero shift", 0b_1000, 0, 0b_1000);
verify::<SrlOp>("all zeros", 0b_0000, 0, 0b_0000);
verify::<SrlOp>("base is zero", 0b_0000, 1, 0b_0000);
}

fn verify<I: RIVInstruction>(rs1_read: u32, rs2_read: u32, expected_rd_written: u32) {
fn verify<I: RIVInstruction>(
name: &'static str,
rs1_read: u32,
rs2_read: u32,
expected_rd_written: u32,
) {
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
let mut cb = CircuitBuilder::new(&mut cs);

let (name, mock_pc, mock_program_op) = match I::INST_KIND {
let (prefix, mock_pc, mock_program_op) = match I::INST_KIND {
InsnKind::SLL => ("SLL", MOCK_PC_SLL, MOCK_PROGRAM[19]),
InsnKind::SRL => ("SRL", MOCK_PC_SRL, MOCK_PROGRAM[20]),
_ => unreachable!(),
};

let config = cb
.namespace(
|| name,
|| format!("{prefix}_({name})"),
|cb| {
let config =
ShiftLogicalInstruction::<GoldilocksExt2, I>::construct_circuit(cb);
Expand Down

0 comments on commit e9eaf3d

Please sign in to comment.