From 243ae2b95d0467b12cf13eacb85344a15436fdca Mon Sep 17 00:00:00 2001 From: naure Date: Thu, 19 Sep 2024 16:18:09 +0200 Subject: [PATCH] Logic Ops (#246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/chip_handler/general.rs | 63 +++++--- ceno_zkvm/src/instructions/riscv.rs | 1 + ceno_zkvm/src/instructions/riscv/logic.rs | 29 ++++ .../instructions/riscv/logic/logic_circuit.rs | 131 +++++++++++++++ .../src/instructions/riscv/logic/test.rs | 153 ++++++++++++++++++ ceno_zkvm/src/scheme/mock_prover.rs | 9 ++ ceno_zkvm/src/uint.rs | 1 + ceno_zkvm/src/uint/arithmetic.rs | 2 +- ceno_zkvm/src/uint/logic.rs | 34 ++++ ceno_zkvm/src/witness.rs | 23 ++- 10 files changed, 418 insertions(+), 28 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/logic.rs create mode 100644 ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs create mode 100644 ceno_zkvm/src/instructions/riscv/logic/test.rs create mode 100644 ceno_zkvm/src/uint/logic.rs diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 4105d604b..e2f7610bd 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -235,40 +235,57 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.assert_u16(name_fn, expr * Expression::from(1 << 15)) } - /// Assert `a & b = res` and that `a, b, res` are all bytes. - pub(crate) fn lookup_and_byte( + /// Assert `rom_type(a, b) = c` and that `a, b, c` are all bytes. + pub fn logic_u8( &mut self, + rom_type: ROMType, a: Expression, b: Expression, - res: Expression, + c: Expression, ) -> Result<(), ZKVMError> { - let items: Vec> = vec![ - Expression::Constant(E::BaseField::from(ROMType::And as u64)), - a, - b, - res, - ]; + let items: Vec> = vec![(rom_type as usize).into(), a, b, c]; let rlc_record = self.rlc_chip_record(items); - self.lk_record(|| "and lookup record", rlc_record)?; - Ok(()) + self.lk_record(|| format!("lookup_{:?}", rom_type), rlc_record) } - /// Assert that `(a < b) == res as bool`, that `a, b` are unsigned bytes, and that `res` is 0 or 1. - pub(crate) fn lookup_ltu_limb8( + /// Assert `a & b = c` and that `a, b, c` are all bytes. + pub fn lookup_and_byte( &mut self, a: Expression, b: Expression, - res: Expression, + c: Expression, ) -> Result<(), ZKVMError> { - let items: Vec> = vec![ - Expression::Constant(E::BaseField::from(ROMType::Ltu as u64)), - a, - b, - res, - ]; - let rlc_record = self.rlc_chip_record(items); - self.lk_record(|| "ltu lookup record", rlc_record)?; - Ok(()) + self.logic_u8(ROMType::And, a, b, c) + } + + /// Assert `a | b = c` and that `a, b, c` are all bytes. + pub fn lookup_or_byte( + &mut self, + a: Expression, + b: Expression, + c: Expression, + ) -> Result<(), ZKVMError> { + self.logic_u8(ROMType::Or, a, b, c) + } + + /// Assert `a ^ b = c` and that `a, b, c` are all bytes. + pub fn lookup_xor_byte( + &mut self, + a: Expression, + b: Expression, + c: Expression, + ) -> Result<(), ZKVMError> { + self.logic_u8(ROMType::Xor, a, b, c) + } + + /// Assert that `(a < b) == c as bool`, that `a, b` are unsigned bytes, and that `c` is 0 or 1. + pub fn lookup_ltu_byte( + &mut self, + a: Expression, + b: Expression, + c: Expression, + ) -> Result<(), ZKVMError> { + self.logic_u8(ROMType::Ltu, a, b, c) } /// less_than diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 2a7144c63..7af831d01 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -4,6 +4,7 @@ pub mod arith; pub mod blt; pub mod config; pub mod constants; +pub mod logic; mod r_insn; diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs new file mode 100644 index 000000000..0f2aa485b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -0,0 +1,29 @@ +mod logic_circuit; +use logic_circuit::{LogicInstruction, LogicOp}; + +#[cfg(test)] +mod test; + +use crate::tables::{AndTable, OrTable, XorTable}; +use ceno_emul::InsnKind; + +pub struct AndOp; +impl LogicOp for AndOp { + const INST_KIND: InsnKind = InsnKind::AND; + type OpsTable = AndTable; +} +pub type AndInstruction = LogicInstruction; + +pub struct OrOp; +impl LogicOp for OrOp { + const INST_KIND: InsnKind = InsnKind::OR; + type OpsTable = OrTable; +} +pub type OrInstruction = LogicInstruction; + +pub struct XorOp; +impl LogicOp for XorOp { + const INST_KIND: InsnKind = InsnKind::XOR; + type OpsTable = XorTable; +} +pub type XorInstruction = LogicInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs new file mode 100644 index 000000000..9f990d1f5 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -0,0 +1,131 @@ +//! The circuit implementation of logic instructions. + +use core::mem::MaybeUninit; +use ff_ext::ExtensionField; +use std::marker::PhantomData; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + riscv::{constants::UInt8, r_insn::RInstructionConfig}, + Instruction, + }, + tables::OpsTable, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord, Word, WORD_SIZE}; + +/// This trait defines a logic instruction, connecting an instruction type to a lookup table. +pub trait LogicOp { + const INST_KIND: InsnKind; + type OpsTable: OpsTable; +} + +/// The Instruction circuit for a given LogicOp. +pub struct LogicInstruction(PhantomData<(E, I)>); + +impl Instruction for LogicInstruction { + type InstructionConfig = LogicConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let config = LogicConfig::construct_circuit(cb, I::INST_KIND)?; + + // Constrain the registers based on the given lookup table. + UInt8::logic( + cb, + I::OpsTable::ROM_TYPE, + &config.rs1_read, + &config.rs2_read, + &config.rd_written, + )?; + + Ok(config) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + UInt8::::logic_assign::( + lk_multiplicity, + step.rs1().unwrap().value as u64, + step.rs2().unwrap().value as u64, + ); + + config.assign_instance(instance, lk_multiplicity, step) + } +} + +/// This config implements R-Instructions that represent registers values as 4 * u8. +/// Non-generic code shared by several circuits. +#[derive(Debug)] +pub struct LogicConfig { + r_insn: RInstructionConfig, + + rs1_read: UInt8, + rs2_read: UInt8, + rd_written: UInt8, +} + +impl LogicConfig { + fn construct_circuit( + cb: &mut CircuitBuilder, + insn_kind: InsnKind, + ) -> Result { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", cb)?; + let rs2_read = UInt8::new_unchecked(|| "rs2_read", cb)?; + let rd_written = UInt8::new_unchecked(|| "rd_written", cb)?; + + let r_insn = RInstructionConfig::::construct_circuit( + cb, + insn_kind, + &rs1_read, + &rs2_read, + &rd_written, + )?; + + Ok(Self { + r_insn, + rs1_read, + rs2_read, + rd_written, + }) + } + + fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + self.r_insn + .assign_instance(instance, lk_multiplicity, step)?; + + let rs1_read = Self::u8_limbs(step.rs1().unwrap().value); + self.rs1_read.assign_limbs(instance, rs1_read); + + let rs2_read = Self::u8_limbs(step.rs2().unwrap().value); + self.rs2_read.assign_limbs(instance, rs2_read); + + let rd_written = Self::u8_limbs(step.rd().unwrap().value.after); + self.rd_written.assign_limbs(instance, rd_written); + + Ok(()) + } + + /// Decompose a word into byte field elements in little-endian order. + fn u8_limbs(v: Word) -> Vec { + let mut limbs = Vec::with_capacity(WORD_SIZE); + for i in 0..WORD_SIZE { + limbs.push(E::BaseField::from(((v >> (i * 8)) & 0xff) as u64)); + } + limbs + } +} diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs new file mode 100644 index 000000000..3e3db1ef8 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -0,0 +1,153 @@ +use ceno_emul::{Change, StepRecord, Word}; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; +use multilinear_extensions::mle::IntoMLEs; + +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + scheme::mock_prover::{MockProver, MOCK_PC_AND, MOCK_PC_OR, MOCK_PC_XOR, MOCK_PROGRAM}, + ROMType, +}; + +use super::*; + +const A: Word = 0xbead1010; +const B: Word = 0xef552020; +// The pair of bytes from A and B. +const LOOKUPS: &[(u64, usize)] = &[(0x2010, 2), (0x55ad, 1), (0xefbe, 1)]; + +#[test] +fn test_opcode_and() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "and", + |cb| { + let config = AndInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let (raw_witin, lkm) = AndInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_AND, + MOCK_PROGRAM[3], + A, + B, + Change::new(0, A & B), + 0, + )], + ) + .unwrap(); + + let lkm = lkm.into_finalize_result()[ROMType::And as usize].clone(); + assert_eq!(&lkm.into_iter().sorted().collect_vec(), LOOKUPS); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} + +#[test] +fn test_opcode_or() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "or", + |cb| { + let config = OrInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let (raw_witin, lkm) = OrInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_OR, + MOCK_PROGRAM[4], + A, + B, + Change::new(0, A | B), + 0, + )], + ) + .unwrap(); + + let lkm = lkm.into_finalize_result()[ROMType::Or as usize].clone(); + assert_eq!(&lkm.into_iter().sorted().collect_vec(), LOOKUPS); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} + +#[test] +fn test_opcode_xor() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "xor", + |cb| { + let config = XorInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let (raw_witin, lkm) = XorInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_XOR, + MOCK_PROGRAM[5], + A, + B, + Change::new(0, A ^ B), + 0, + )], + ) + .unwrap(); + + let lkm = lkm.into_finalize_result()[ROMType::Xor as usize].clone(); + assert_eq!(&lkm.into_iter().sorted().collect_vec(), LOOKUPS); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 63bf593a7..07a1384c4 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -43,11 +43,20 @@ pub const MOCK_PROGRAM: &[u32] = &[ 0x20 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, // mul (0x01, 0x00, 0x33) 0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, + // and x4, x2, x3 + 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b111 << 12 | MOCK_RD << 7 | 0x33, + // or x4, x2, x3 + 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b110 << 12 | MOCK_RD << 7 | 0x33, + // xor x4, x2, x3 + 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b100 << 12 | MOCK_RD << 7 | 0x33, ]; // Addresses of particular instructions in the mock program. pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); pub const MOCK_PC_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 4); pub const MOCK_PC_MUL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 8); +pub const MOCK_PC_AND: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 12); +pub const MOCK_PC_OR: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 16); +pub const MOCK_PC_XOR: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 20); #[allow(clippy::enum_variant_names)] #[derive(Debug, PartialEq, Clone)] diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 5a657a8bb..bb70a7103 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -1,5 +1,6 @@ mod arithmetic; pub mod constants; +mod logic; pub mod util; use crate::{ diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index c8ae9bb0b..cf242ccc4 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -348,7 +348,7 @@ impl UIntLimbs { let is_ltu = circuit_builder.create_witin(|| "is_ltu")?; // circuit_builder.assert_bit(is_ltu.expr())?; // lookup ensure it is bit // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) - circuit_builder.lookup_ltu_limb8(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; + circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; Ok(UIntLtuConfig { byte_diff_inv, indexes, diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs new file mode 100644 index 000000000..78c3148ad --- /dev/null +++ b/ceno_zkvm/src/uint/logic.rs @@ -0,0 +1,34 @@ +use ff_ext::ExtensionField; +use itertools::izip; + +use super::UIntLimbs; +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, expression::ToExpr, tables::OpsTable, + witness::LkMultiplicity, ROMType, +}; + +// Only implemented for u8 limbs. +impl UIntLimbs { + /// Assert `rom_type(a, b) = c` and range-check `a, b, c`. + /// This works with a lookup for each u8 limb. + pub fn logic( + cb: &mut CircuitBuilder, + rom_type: ROMType, + a: &Self, + b: &Self, + c: &Self, + ) -> Result<(), ZKVMError> { + for (a_byte, b_byte, c_byte) in izip!(a.limbs.iter(), b.limbs.iter(), c.limbs.iter()) { + cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; + } + Ok(()) + } + + pub fn logic_assign(lk_multiplicity: &mut LkMultiplicity, a: u64, b: u64) { + for i in 0..M.div_ceil(8) { + let a_byte = (a >> (i * 8)) & 0xff; + let b_byte = (b >> (i * 8)) & 0xff; + lk_multiplicity.logic_u8::(a_byte, b_byte); + } + } +} diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index db437181c..93b01cc04 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -16,7 +16,7 @@ use thread_local::ThreadLocal; use crate::{ structs::ROMType, - tables::{AndTable, LtuTable, OpsTable}, + tables::{AndTable, LtuTable, OpsTable, OrTable, XorTable}, }; #[macro_export] @@ -110,14 +110,29 @@ impl LkMultiplicity { } } + /// Track a lookup into a logic table (AndTable, etc). + pub fn logic_u8(&mut self, a: u64, b: u64) { + self.increment(OP::ROM_TYPE, OP::pack(a, b)); + } + /// lookup a AND b pub fn lookup_and_byte(&mut self, a: u64, b: u64) { - self.increment(ROMType::And, AndTable::pack(a, b)); + self.logic_u8::(a, b) + } + + /// lookup a OR b + pub fn lookup_or_byte(&mut self, a: u64, b: u64) { + self.logic_u8::(a, b) + } + + /// lookup a XOR b + pub fn lookup_xor_byte(&mut self, a: u64, b: u64) { + self.logic_u8::(a, b) } /// lookup a < b as unsigned byte - pub fn lookup_ltu_limb8(&mut self, a: u64, b: u64) { - self.increment(ROMType::Ltu, LtuTable::pack(a, b)); + pub fn lookup_ltu_byte(&mut self, a: u64, b: u64) { + self.logic_u8::(a, b) } /// Fetch instruction at pc