diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 598940d22..a4f58fffc 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -58,5 +58,5 @@ jobs: RAYON_NUM_THREADS: 2 with: command: run - args: --package ceno_zkvm --example riscv_add --target ${{ matrix.target }} -- --start 10 --end 11 + args: --package ceno_zkvm --example riscv_opcodes --target ${{ matrix.target }} -- --start 9 --end 10 diff --git a/Cargo.lock b/Cargo.lock index 2db07c4e4..f87fc142f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1180,7 +1180,7 @@ dependencies = [ "num-bigint", "num-integer", "plonky2", - "poseidon", + "poseidon 0.2.0", "rand", "rand_chacha", "rayon", @@ -1477,6 +1477,20 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poseidon" +version = "0.1.0" +dependencies = [ + "ark-std", + "criterion", + "ff", + "goldilocks", + "plonky2", + "rand", + "serde", + "unroll", +] + [[package]] name = "poseidon" version = "0.2.0" @@ -2205,7 +2219,7 @@ dependencies = [ "ff_ext", "goldilocks", "halo2curves 0.1.0", - "poseidon", + "poseidon 0.2.0", "rayon", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index 75c8f859a..a5bd699c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,8 @@ members = [ "singer-utils", "sumcheck", "transcript", - "ceno_zkvm" + "ceno_zkvm", + "poseidon" ] [workspace.package] diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index db2a8699d..86c7c71cb 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -121,6 +121,18 @@ pub enum InsnCategory { System, Invalid, } +use InsnCategory::*; + +#[derive(Clone, Copy, Debug)] +pub enum InsnFormat { + R, + I, + S, + B, + U, + J, +} +use InsnFormat::*; #[derive(Clone, Copy, Debug, PartialEq, EnumIter)] #[allow(clippy::upper_case_acronyms)] @@ -174,6 +186,7 @@ pub enum InsnKind { EANY, MRET, } +use InsnKind::*; impl InsnKind { pub const fn codes(self) -> InsnCodes { @@ -183,6 +196,7 @@ impl InsnKind { #[derive(Clone, 
Copy, Debug)] pub struct InsnCodes { + pub format: InsnFormat, pub kind: InsnKind, category: InsnCategory, pub opcode: u32, @@ -234,30 +248,99 @@ impl DecodedInstruction { self.opcode } + /// Get the rd field, regardless of the instruction format. pub fn rd(&self) -> u32 { self.rd } + /// Get the register destination, or zero if the instruction does not write to a register. + pub fn rd_or_zero(&self) -> u32 { + match self.codes().format { + R | I | U | J => self.rd, + _ => 0, + } + } + + /// Get the funct3 field, regardless of the instruction format. pub fn funct3(&self) -> u32 { self.func3 } + /// Get the funct3 field, or zero if the instruction does not use funct3. + pub fn funct3_or_zero(&self) -> u32 { + match self.codes().format { + R | I | S | B => self.func3, + _ => 0, + } + } + + /// Get the rs1 field, regardless of the instruction format. pub fn rs1(&self) -> u32 { self.rs1 } + /// Get the register source 1, or zero if the instruction does not use rs1. + pub fn rs1_or_zero(&self) -> u32 { + match self.codes().format { + R | I | S | B => self.rs1, + _ => 0, + } + } + + /// Get the rs2 field, regardless of the instruction format. pub fn rs2(&self) -> u32 { self.rs2 } + /// Get the register source 2, or zero if the instruction does not use rs2. + pub fn rs2_or_zero(&self) -> u32 { + match self.codes().format { + R | S | B => self.rs2, + _ => 0, + } + } + + /// Get the funct7 field, regardless of the instruction format. pub fn funct7(&self) -> u32 { self.func7 } + /// Get the decoded immediate, or 2^shift, or the funct7 field, depending on the instruction format. + pub fn imm_or_funct7(&self) -> u32 { + match self.codes() { + InsnCodes { format: R, .. } => self.func7, + InsnCodes { + kind: SLLI | SRLI | SRAI, + .. + } => 1 << self.rs2(), // decode the shift as a multiplication by 2.pow(rs2) + InsnCodes { format: I, .. } => self.imm_i(), + InsnCodes { format: S, .. } => self.imm_s(), + InsnCodes { format: B, .. } => self.imm_b(), + InsnCodes { format: U, .. 
} => self.imm_u(), + InsnCodes { format: J, .. } => self.imm_j(), + } + } + + /// Indicate whether the immediate is interpreted as a signed integer, and it is negative. + pub fn imm_is_negative(&self) -> bool { + match self.codes() { + InsnCodes { format: R, .. } => false, + InsnCodes { + kind: SLLI | SRLI | SRAI, + .. + } => false, + _ => self.top_bit != 0, + } + } + pub fn sign_bit(&self) -> u32 { self.top_bit } + pub fn codes(&self) -> InsnCodes { + FastDecodeTable::get().lookup(self) + } + pub fn kind(&self) -> (InsnCategory, InsnKind) { let i = FastDecodeTable::get().lookup(self); (i.category, i.kind) @@ -292,7 +375,34 @@ impl DecodedInstruction { } } +#[cfg(test)] +#[test] +fn test_decode_imm() { + for (i, expected) in [ + // Example of I-type: ADDI. + // imm | rs1 | funct3 | rd | opcode + (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), + // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. + (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), + (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), + ( + 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, + 1 << 31, + ), + // Example of R-type with funct7: SUB. 
+ // funct7 | rs2 | rs1 | funct3 | rd | opcode + ( + 0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, + 0x20, + ), + ] { + let imm = DecodedInstruction::new(i).imm_or_funct7(); + assert_eq!(imm, expected); + } +} + const fn insn( + format: InsnFormat, kind: InsnKind, category: InsnCategory, opcode: u32, @@ -300,6 +410,7 @@ const fn insn( func7: i32, ) -> InsnCodes { InsnCodes { + format, kind, category, opcode, @@ -312,54 +423,54 @@ type InstructionTable = [InsnCodes; 48]; type FastInstructionTable = [u8; 1 << 10]; const RV32IM_ISA: InstructionTable = [ - insn(InsnKind::INVALID, InsnCategory::Invalid, 0x00, 0x0, 0x00), - insn(InsnKind::ADD, InsnCategory::Compute, 0x33, 0x0, 0x00), - insn(InsnKind::SUB, InsnCategory::Compute, 0x33, 0x0, 0x20), - insn(InsnKind::XOR, InsnCategory::Compute, 0x33, 0x4, 0x00), - insn(InsnKind::OR, InsnCategory::Compute, 0x33, 0x6, 0x00), - insn(InsnKind::AND, InsnCategory::Compute, 0x33, 0x7, 0x00), - insn(InsnKind::SLL, InsnCategory::Compute, 0x33, 0x1, 0x00), - insn(InsnKind::SRL, InsnCategory::Compute, 0x33, 0x5, 0x00), - insn(InsnKind::SRA, InsnCategory::Compute, 0x33, 0x5, 0x20), - insn(InsnKind::SLT, InsnCategory::Compute, 0x33, 0x2, 0x00), - insn(InsnKind::SLTU, InsnCategory::Compute, 0x33, 0x3, 0x00), - insn(InsnKind::ADDI, InsnCategory::Compute, 0x13, 0x0, -1), - insn(InsnKind::XORI, InsnCategory::Compute, 0x13, 0x4, -1), - insn(InsnKind::ORI, InsnCategory::Compute, 0x13, 0x6, -1), - insn(InsnKind::ANDI, InsnCategory::Compute, 0x13, 0x7, -1), - insn(InsnKind::SLLI, InsnCategory::Compute, 0x13, 0x1, 0x00), - insn(InsnKind::SRLI, InsnCategory::Compute, 0x13, 0x5, 0x00), - insn(InsnKind::SRAI, InsnCategory::Compute, 0x13, 0x5, 0x20), - insn(InsnKind::SLTI, InsnCategory::Compute, 0x13, 0x2, -1), - insn(InsnKind::SLTIU, InsnCategory::Compute, 0x13, 0x3, -1), - insn(InsnKind::BEQ, InsnCategory::Compute, 0x63, 0x0, -1), - insn(InsnKind::BNE, InsnCategory::Compute, 0x63, 0x1, -1), - insn(InsnKind::BLT, InsnCategory::Compute, 
0x63, 0x4, -1), - insn(InsnKind::BGE, InsnCategory::Compute, 0x63, 0x5, -1), - insn(InsnKind::BLTU, InsnCategory::Compute, 0x63, 0x6, -1), - insn(InsnKind::BGEU, InsnCategory::Compute, 0x63, 0x7, -1), - insn(InsnKind::JAL, InsnCategory::Compute, 0x6f, -1, -1), - insn(InsnKind::JALR, InsnCategory::Compute, 0x67, 0x0, -1), - insn(InsnKind::LUI, InsnCategory::Compute, 0x37, -1, -1), - insn(InsnKind::AUIPC, InsnCategory::Compute, 0x17, -1, -1), - insn(InsnKind::MUL, InsnCategory::Compute, 0x33, 0x0, 0x01), - insn(InsnKind::MULH, InsnCategory::Compute, 0x33, 0x1, 0x01), - insn(InsnKind::MULHSU, InsnCategory::Compute, 0x33, 0x2, 0x01), - insn(InsnKind::MULHU, InsnCategory::Compute, 0x33, 0x3, 0x01), - insn(InsnKind::DIV, InsnCategory::Compute, 0x33, 0x4, 0x01), - insn(InsnKind::DIVU, InsnCategory::Compute, 0x33, 0x5, 0x01), - insn(InsnKind::REM, InsnCategory::Compute, 0x33, 0x6, 0x01), - insn(InsnKind::REMU, InsnCategory::Compute, 0x33, 0x7, 0x01), - insn(InsnKind::LB, InsnCategory::Load, 0x03, 0x0, -1), - insn(InsnKind::LH, InsnCategory::Load, 0x03, 0x1, -1), - insn(InsnKind::LW, InsnCategory::Load, 0x03, 0x2, -1), - insn(InsnKind::LBU, InsnCategory::Load, 0x03, 0x4, -1), - insn(InsnKind::LHU, InsnCategory::Load, 0x03, 0x5, -1), - insn(InsnKind::SB, InsnCategory::Store, 0x23, 0x0, -1), - insn(InsnKind::SH, InsnCategory::Store, 0x23, 0x1, -1), - insn(InsnKind::SW, InsnCategory::Store, 0x23, 0x2, -1), - insn(InsnKind::EANY, InsnCategory::System, 0x73, 0x0, 0x00), - insn(InsnKind::MRET, InsnCategory::System, 0x73, 0x0, 0x18), + insn(R, INVALID, Invalid, 0x00, 0x0, 0x00), + insn(R, ADD, Compute, 0x33, 0x0, 0x00), + insn(R, SUB, Compute, 0x33, 0x0, 0x20), + insn(R, XOR, Compute, 0x33, 0x4, 0x00), + insn(R, OR, Compute, 0x33, 0x6, 0x00), + insn(R, AND, Compute, 0x33, 0x7, 0x00), + insn(R, SLL, Compute, 0x33, 0x1, 0x00), + insn(R, SRL, Compute, 0x33, 0x5, 0x00), + insn(R, SRA, Compute, 0x33, 0x5, 0x20), + insn(R, SLT, Compute, 0x33, 0x2, 0x00), + insn(R, SLTU, Compute, 0x33, 
0x3, 0x00), + insn(I, ADDI, Compute, 0x13, 0x0, -1), + insn(I, XORI, Compute, 0x13, 0x4, -1), + insn(I, ORI, Compute, 0x13, 0x6, -1), + insn(I, ANDI, Compute, 0x13, 0x7, -1), + insn(I, SLLI, Compute, 0x13, 0x1, 0x00), + insn(I, SRLI, Compute, 0x13, 0x5, 0x00), + insn(I, SRAI, Compute, 0x13, 0x5, 0x20), + insn(I, SLTI, Compute, 0x13, 0x2, -1), + insn(I, SLTIU, Compute, 0x13, 0x3, -1), + insn(B, BEQ, Compute, 0x63, 0x0, -1), + insn(B, BNE, Compute, 0x63, 0x1, -1), + insn(B, BLT, Compute, 0x63, 0x4, -1), + insn(B, BGE, Compute, 0x63, 0x5, -1), + insn(B, BLTU, Compute, 0x63, 0x6, -1), + insn(B, BGEU, Compute, 0x63, 0x7, -1), + insn(J, JAL, Compute, 0x6f, -1, -1), + insn(I, JALR, Compute, 0x67, 0x0, -1), + insn(U, LUI, Compute, 0x37, -1, -1), + insn(U, AUIPC, Compute, 0x17, -1, -1), + insn(R, MUL, Compute, 0x33, 0x0, 0x01), + insn(R, MULH, Compute, 0x33, 0x1, 0x01), + insn(R, MULHSU, Compute, 0x33, 0x2, 0x01), + insn(R, MULHU, Compute, 0x33, 0x3, 0x01), + insn(R, DIV, Compute, 0x33, 0x4, 0x01), + insn(R, DIVU, Compute, 0x33, 0x5, 0x01), + insn(R, REM, Compute, 0x33, 0x6, 0x01), + insn(R, REMU, Compute, 0x33, 0x7, 0x01), + insn(I, LB, Load, 0x03, 0x0, -1), + insn(I, LH, Load, 0x03, 0x1, -1), + insn(I, LW, Load, 0x03, 0x2, -1), + insn(I, LBU, Load, 0x03, 0x4, -1), + insn(I, LHU, Load, 0x03, 0x5, -1), + insn(S, SB, Store, 0x23, 0x0, -1), + insn(S, SH, Store, 0x23, 0x1, -1), + insn(S, SW, Store, 0x23, 0x2, -1), + insn(I, EANY, System, 0x73, 0x0, 0x00), + insn(I, MRET, System, 0x73, 0x0, 0x18), ]; #[cfg(test)] diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 2a73ca4ac..ea636ee6e 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -86,6 +86,62 @@ impl StepRecord { } } + pub fn new_b_instruction( + cycle: Cycle, + pc: Change, + insn_code: Word, + rs1_read: Word, + rs2_read: Word, + previous_cycle: Cycle, + ) -> StepRecord { + let insn = DecodedInstruction::new(insn_code); + StepRecord { + cycle, + pc, + insn_code, + rs1: Some(ReadOp { + 
addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(), + value: rs1_read, + previous_cycle, + }), + rs2: Some(ReadOp { + addr: CENO_PLATFORM.register_vma(insn.rs2() as RegIdx).into(), + value: rs2_read, + previous_cycle, + }), + rd: None, + memory_op: None, + } + } + + pub fn new_i_instruction( + cycle: Cycle, + pc: ByteAddr, + insn_code: Word, + rs1_read: Word, + rd: Change, + previous_cycle: Cycle, + ) -> StepRecord { + let insn = DecodedInstruction::new(insn_code); + StepRecord { + cycle, + pc: Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1: Some(ReadOp { + addr: CENO_PLATFORM.register_vma(insn.rs1() as RegIdx).into(), + value: rs1_read, + previous_cycle, + }), + rs2: None, + rd: Some(WriteOp { + addr: CENO_PLATFORM.register_vma(insn.rd() as RegIdx).into(), + value: rd, + previous_cycle, + }), + memory_op: None, + } + } + pub fn cycle(&self) -> Cycle { self.cycle } diff --git a/ceno_rt/src/params.rs b/ceno_rt/src/params.rs index b8a30b24c..708dde472 100644 --- a/ceno_rt/src/params.rs +++ b/ceno_rt/src/params.rs @@ -1,3 +1,3 @@ pub const WORD_SIZE: usize = 4; -pub const INFO_OUT_ADDR: u32 = 0xC000_0000; \ No newline at end of file +pub const INFO_OUT_ADDR: u32 = 0xC000_0000; diff --git a/ceno_zkvm/Makefile.toml b/ceno_zkvm/Makefile.toml index 77d33a50a..86ac6b908 100644 --- a/ceno_zkvm/Makefile.toml +++ b/ceno_zkvm/Makefile.toml @@ -3,7 +3,7 @@ CARGO_MAKE_EXTEND_WORKSPACE_MAKEFILE = true CORE = { script = ["grep ^cpu\\scores /proc/cpuinfo | uniq | awk '{print $4}'"] } RAYON_NUM_THREADS = "${CORE}" -[tasks.riscv_add_flamegraph] +[tasks.riscv_opcodes_flamegraph] env = { "RUST_LOG" = "debug", "RAYON_NUM_THREADS" = "8"} command = "cargo" -args = ["run", "--package", "ceno_zkvm", "--release", "--example", "riscv_add"] +args = ["run", "--package", "ceno_zkvm", "--release", "--example", "riscv_opcodes"] diff --git a/ceno_zkvm/examples/riscv_add.rs b/ceno_zkvm/examples/riscv_opcodes.rs similarity index 67% rename from ceno_zkvm/examples/riscv_add.rs 
rename to ceno_zkvm/examples/riscv_opcodes.rs index c5ab38d26..9aa7631d6 100644 --- a/ceno_zkvm/examples/riscv_add.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,17 +1,22 @@ use std::{iter, time::Instant}; use ceno_zkvm::{ - instructions::riscv::arith::AddInstruction, scheme::prover::ZKVMProver, + instructions::riscv::{arith::AddInstruction, blt::BltInstruction}, + scheme::prover::ZKVMProver, tables::ProgramTableCircuit, }; use clap::Parser; use const_env::from_env; -use ceno_emul::{ByteAddr, InsnKind::ADD, StepRecord, VMState, CENO_PLATFORM}; +use ceno_emul::{ + ByteAddr, + InsnKind::{ADD, BLT}, + StepRecord, VMState, CENO_PLATFORM, +}; use ceno_zkvm::{ scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, - tables::U16TableCircuit, + tables::{AndTableCircuit, LtuTableCircuit, U16TableCircuit}, }; use goldilocks::GoldilocksExt2; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; @@ -33,11 +38,11 @@ const RAYON_NUM_THREADS: usize = 8; #[allow(clippy::unusual_byte_groupings)] const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011; #[allow(clippy::unusual_byte_groupings)] -const PROGRAM_ADD_LOOP: [u32; 4] = [ +const PROGRAM_CODE: [u32; 4] = [ // func7 rs2 rs1 f3 rd opcode 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1 - 0b_1_111111_00000_00011_001_1100_1_1100011, // bne x3, x0, -8 + 0b_1_111111_00011_00000_100_1100_1_1100011, // blt x0, x3, -8 ECALL_HALT, // ecall halt ]; @@ -50,7 +55,7 @@ struct Args { start: u8, /// end round - #[arg(short, long, default_value_t = 22)] + #[arg(short, long, default_value_t = 9)] end: u8, } @@ -96,11 +101,17 @@ fn main() { let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = 
ZKVMConstraintSystem::default(); + // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); - let range_config = zkvm_cs.register_table_circuit::>(); + let blt_config = zkvm_cs.register_opcode_circuit::(); + // tables + let u16_range_config = zkvm_cs.register_table_circuit::>(); + // let u1_range_config = zkvm_cs.register_table_circuit::>(); + let and_config = zkvm_cs.register_table_circuit::>(); + let ltu_config = zkvm_cs.register_table_circuit::>(); let prog_config = zkvm_cs.register_table_circuit::>(); - let program_add_loop: Vec = PROGRAM_ADD_LOOP + let program_code: Vec = PROGRAM_CODE .iter() .cloned() .chain(iter::repeat(ECALL_HALT)) @@ -108,15 +119,32 @@ fn main() { .collect(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); + zkvm_fixed_traces.register_opcode_circuit::(&zkvm_cs); + zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, - range_config.clone(), + u16_range_config.clone(), + &(), + ); + // zkvm_fixed_traces.register_table_circuit::>( + // &zkvm_cs, + // u1_range_config.clone(), + // &(), + // ); + zkvm_fixed_traces.register_table_circuit::>( + &zkvm_cs, + and_config.clone(), + &(), + ); + zkvm_fixed_traces.register_table_circuit::>( + &zkvm_cs, + ltu_config.clone(), &(), ); zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, prog_config.clone(), - &program_add_loop, + &program_code, ); let pk = zkvm_cs @@ -139,33 +167,60 @@ fn main() { vm.init_register_unsafe(1usize, 1); vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement vm.init_register_unsafe(3usize, step_loop as u32); - for (i, inst) in program_add_loop.iter().enumerate() { + for (i, inst) in program_code.iter().enumerate() { vm.init_memory(pc_start + i, *inst); } - let records = vm + + let all_records = vm .iter_until_success() .collect::, _>>() .expect("vm exec failed") .into_iter() - .filter(|record| record.insn().kind().1 == ADD) .collect::>(); - tracing::info!("tracer generated {} ADD 
records", records.len()); + let mut add_records = Vec::new(); + let mut blt_records = Vec::new(); + all_records.iter().for_each(|record| { + let kind = record.insn().kind().1; + if kind == ADD { + add_records.push(record.clone()); + } else if kind == BLT { + blt_records.push(record.clone()); + } + }); + + tracing::info!( + "tracer generated {} ADD records, {} BLT records", + add_records.len(), + blt_records.len() + ); let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &add_config, records) + .assign_opcode_circuit::>(&zkvm_cs, &add_config, add_records) + .unwrap(); + zkvm_witness + .assign_opcode_circuit::(&zkvm_cs, &blt_config, blt_records) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); // assign table circuits zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &range_config, &()) + .assign_table_circuit::>(&zkvm_cs, &u16_range_config, &()) + .unwrap(); + // zkvm_witness + // .assign_table_circuit::>(&zkvm_cs, &u1_range_config, &()) + // .unwrap(); + zkvm_witness + .assign_table_circuit::>(&zkvm_cs, &and_config, &()) + .unwrap(); + zkvm_witness + .assign_table_circuit::>(&zkvm_cs, <u_config, &()) .unwrap(); zkvm_witness .assign_table_circuit::>( &zkvm_cs, &prog_config, - &program_add_loop.len(), + &program_code.len(), ) .unwrap(); @@ -177,7 +232,7 @@ fn main() { .expect("create_proof failed"); println!( - "AddInstruction::create_proof, instance_num_vars = {}, time = {}", + "riscv_opcodes::create_proof, instance_num_vars = {}, time = {}", instance_num_vars, timer.elapsed().as_secs_f64() ); diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index d78bd67d7..755278c0c 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -8,6 +8,7 @@ use crate::{ pub mod general; pub mod global_state; +pub mod memory; pub mod register; pub mod utils; @@ -19,8 +20,7 @@ pub trait GlobalStateRegisterMachineChipOperations { /// The common representation of 
a register value. /// Format: `[u16; 2]`, least-significant-first. -#[derive(Debug, Clone)] -pub struct RegisterExpr(pub [Expression; 2]); +pub type RegisterExpr = [Expression; 2]; pub trait RegisterChipOperations, N: FnOnce() -> NR> { fn register_read( @@ -43,3 +43,31 @@ pub trait RegisterChipOperations, N: FnOnce( value: RegisterExpr, ) -> Result<(Expression, ExprLtConfig), ZKVMError>; } + +/// The common representation of a memory value. +/// Format: `[u16; 2]`, least-significant-first. +pub type MemoryExpr = [Expression; 2]; + +pub trait MemoryChipOperations, N: FnOnce() -> NR> { + #[allow(dead_code)] + fn memory_read( + &mut self, + name_fn: N, + memory_addr: &WitIn, + prev_ts: Expression, + ts: Expression, + value: crate::chip_handler::MemoryExpr, + ) -> Result<(Expression, ExprLtConfig), ZKVMError>; + + #[allow(clippy::too_many_arguments)] + #[allow(dead_code)] + fn memory_write( + &mut self, + name_fn: N, + memory_addr: &WitIn, + prev_ts: Expression, + ts: Expression, + prev_values: crate::chip_handler::MemoryExpr, + value: crate::chip_handler::MemoryExpr, + ) -> Result<(Expression, ExprLtConfig), ZKVMError>; +} diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index e2f7610bd..53c8ea847 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -146,6 +146,29 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ) } + pub fn condition_require_equal( + &mut self, + name_fn: N, + cond: Expression, + target: Expression, + true_expr: Expression, + false_expr: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + // cond * (true_expr) + (1 - cond) * false_expr + // => false_expr + cond * true_expr - cond * false_expr + self.namespace( + || "cond_require_equal", + |cb| { + let cond_target = false_expr.clone() + cond.clone() * true_expr - cond * false_expr; + cb.cs.require_zero(name_fn, target - cond_target) + }, + ) + } + pub(crate) fn assert_ux( 
&mut self, name_fn: N, @@ -222,6 +245,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } + #[allow(dead_code)] pub(crate) fn assert_bit( &mut self, name_fn: N, @@ -232,7 +256,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR, { // TODO: Replace with `x * (1 - x)` or a multi-bit lookup similar to assert_u8_pair. - self.assert_u16(name_fn, expr * Expression::from(1 << 15)) + let items: Vec> = vec![(ROMType::U1 as usize).into(), expr]; + let rlc_record = self.rlc_chip_record(items); + self.lk_record(name_fn, rlc_record)?; + Ok(()) } /// Assert `rom_type(a, b) = c` and that `a, b, c` are all bytes. diff --git a/ceno_zkvm/src/chip_handler/memory.rs b/ceno_zkvm/src/chip_handler/memory.rs new file mode 100644 index 000000000..3aa922ca4 --- /dev/null +++ b/ceno_zkvm/src/chip_handler/memory.rs @@ -0,0 +1,105 @@ +use crate::{ + chip_handler::{MemoryChipOperations, MemoryExpr}, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::config::ExprLtConfig, + structs::RAMType, +}; +use ff_ext::ExtensionField; + +impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOperations + for CircuitBuilder<'a, E> +{ + #[allow(dead_code)] + fn memory_read( + &mut self, + name_fn: N, + memory_addr: &WitIn, + prev_ts: Expression, + ts: Expression, + value: MemoryExpr, + ) -> Result<(Expression, ExprLtConfig), ZKVMError> { + self.namespace(name_fn, |cb| { + // READ (a, v, t) + let read_record = cb.rlc_chip_record( + [ + vec![Expression::::Constant(E::BaseField::from( + RAMType::Memory as u64, + ))], + vec![memory_addr.expr()], + value.to_vec(), + vec![prev_ts.clone()], + ] + .concat(), + ); + // Write (a, v, t) + let write_record = cb.rlc_chip_record( + [ + vec![Expression::::Constant(E::BaseField::from( + RAMType::Memory as u64, + ))], + vec![memory_addr.expr()], + value.to_vec(), + vec![ts.clone()], + ] + .concat(), + ); + cb.read_record(|| "read_record", read_record)?; + 
cb.write_record(|| "write_record", write_record)?; + + // assert prev_ts < current_ts + let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?; + + let next_ts = ts + 1.into(); + + Ok((next_ts, lt_cfg)) + }) + } + + #[allow(dead_code)] + fn memory_write( + &mut self, + name_fn: N, + memory_addr: &WitIn, + prev_ts: Expression, + ts: Expression, + prev_values: MemoryExpr, + value: MemoryExpr, + ) -> Result<(Expression, ExprLtConfig), ZKVMError> { + self.namespace(name_fn, |cb| { + // READ (a, v, t) + let read_record = cb.rlc_chip_record( + [ + vec![Expression::::Constant(E::BaseField::from( + RAMType::Memory as u64, + ))], + vec![memory_addr.expr()], + prev_values.to_vec(), + vec![prev_ts.clone()], + ] + .concat(), + ); + // Write (a, v, t) + let write_record = cb.rlc_chip_record( + [ + vec![Expression::::Constant(E::BaseField::from( + RAMType::Memory as u64, + ))], + vec![memory_addr.expr()], + value.to_vec(), + vec![ts.clone()], + ] + .concat(), + ); + cb.read_record(|| "read_record", read_record)?; + cb.write_record(|| "write_record", write_record)?; + + let lt_cfg = cb.less_than(|| "prev_ts < ts", prev_ts, ts.clone(), Some(true))?; + + let next_ts = ts + 1.into(); + + Ok((next_ts, lt_cfg)) + }) + } +} diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index c283cc052..dd1709247 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -29,7 +29,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe RAMType::Register as u64, ))], vec![register_id.expr()], - value.0.to_vec(), + value.to_vec(), vec![prev_ts.clone()], ] .concat(), @@ -41,7 +41,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe RAMType::Register as u64, ))], vec![register_id.expr()], - value.0.to_vec(), + value.to_vec(), vec![ts.clone()], ] .concat(), @@ -75,7 +75,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe 
RAMType::Register as u64, ))], vec![register_id.expr()], - prev_values.0.to_vec(), + prev_values.to_vec(), vec![prev_ts.clone()], ] .concat(), @@ -87,7 +87,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe RAMType::Register as u64, ))], vec![register_id.expr()], - value.0.to_vec(), + value.to_vec(), vec![ts.clone()], ] .concat(), diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 6b616b1fa..773136e66 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -201,8 +201,8 @@ impl ConstraintSystem { assert_eq!( rlc_record.degree(), 1, - "rlc record degree {} != 1", - rlc_record.degree() + "rlc lk_record degree ({})", + name_fn().into() ); self.lk_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); @@ -223,8 +223,8 @@ impl ConstraintSystem { assert_eq!( rlc_record.degree(), 1, - "rlc record degree {} != 1", - rlc_record.degree() + "rlc lk_table_record degree ({})", + name_fn().into() ); self.lk_table_expressions.push(LogupTableExpression { values: rlc_record, @@ -244,8 +244,8 @@ impl ConstraintSystem { assert_eq!( rlc_record.degree(), 1, - "rlc record degree {} != 1", - rlc_record.degree() + "rlc read_record degree ({})", + name_fn().into() ); self.r_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); @@ -261,8 +261,8 @@ impl ConstraintSystem { assert_eq!( rlc_record.degree(), 1, - "rlc record degree {} != 1", - rlc_record.degree() + "rlc write_record degree ({})", + name_fn().into() ); self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); @@ -284,10 +284,13 @@ impl ConstraintSystem { let path = self.ns.compute_path(name_fn().into()); self.assert_zero_expressions_namespace_map.push(path); } else { - assert!( - assert_zero_expr.is_monomial_form(), - "only support sumcheck in monomial form" - ); + let assert_zero_expr = if assert_zero_expr.is_monomial_form() { + assert_zero_expr + } else 
{ + let e = assert_zero_expr.to_monomial_form(); + assert!(e.is_monomial_form(), "failed to put into monomial form"); + e + }; self.max_non_lc_degree = self.max_non_lc_degree.max(assert_zero_expr.degree()); self.assert_zero_sumcheck_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index d1fc023a6..7b948336f 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -10,6 +10,7 @@ pub enum ZKVMError { CircuitError, UtilError(UtilError), WitnessNotFound(String), + InvalidWitness(String), VKNotFound(String), FixedTraceNotFound(String), VerifyError(String), diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 9b8f29679..4f3e82612 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -1,3 +1,5 @@ +mod monomial; + use std::{ cmp::max, mem::MaybeUninit, @@ -14,7 +16,7 @@ use crate::{ structs::{ChallengeId, WitnessId}, }; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Expression { /// WitIn(Id) WitIn(WitnessId), @@ -95,6 +97,10 @@ impl Expression { Self::is_monomial_form_inner(MonomialState::SumTerm, self) } + pub fn to_monomial_form(&self) -> Self { + self.to_monomial_form_inner() + } + pub fn unpack_sum(&self) -> Option<(Expression, Expression)> { match self { Expression::Sum(a, b) => Some((a.deref().clone(), b.deref().clone())), @@ -109,8 +115,10 @@ impl Expression { Expression::Constant(c) => *c == E::BaseField::ZERO, Expression::Sum(a, b) => Self::is_zero_expr(a) && Self::is_zero_expr(b), Expression::Product(a, b) => Self::is_zero_expr(a) || Self::is_zero_expr(b), - Expression::ScaledSum(_, _, _) => false, - Expression::Challenge(_, _, _, _) => false, + Expression::ScaledSum(x, a, b) => { + (Self::is_zero_expr(x) || Self::is_zero_expr(a)) && Self::is_zero_expr(b) + } + Expression::Challenge(_, _, scalar, offset) => *scalar == E::ZERO && *offset == E::ZERO, } } @@ -137,7 +145,9 
@@ impl Expression { && Self::is_monomial_form_inner(MonomialState::ProductTerm, b) } (Expression::ScaledSum(_, _, _), MonomialState::SumTerm) => true, - (Expression::ScaledSum(_, _, b), MonomialState::ProductTerm) => Self::is_zero_expr(b), + (Expression::ScaledSum(x, a, b), MonomialState::ProductTerm) => { + Self::is_zero_expr(x) || Self::is_zero_expr(a) || Self::is_zero_expr(b) + } } } } @@ -341,31 +351,42 @@ impl Mul for Expression { if challenge_id1 == challenge_id2 { // (s1 * s2 * c1^(pow1 + pow2) + offset2 * s1 * c1^(pow1) + offset1 * s2 * c2^(pow2)) // + offset1 * offset2 - Expression::Sum( - Box::new(Expression::Sum( - // (s1 * s2 * c1^(pow1 + pow2) + offset1 * offset2 + + // (s1 * s2 * c1^(pow1 + pow2) + offset1 * offset2 + let mut result = Expression::Challenge( + *challenge_id1, + pow1 + pow2, + *s1 * s2, + *offset1 * offset2, + ); + + // offset2 * s1 * c1^(pow1) + if *s1 != E::ZERO && *offset2 != E::ZERO { + result = Expression::Sum( + Box::new(result), Box::new(Expression::Challenge( *challenge_id1, - pow1 + pow2, - *s1 * s2, - *offset1 * offset2, + *pow1, + *offset2 * *s1, + E::ZERO, )), - // offset2 * s1 * c1^(pow1) + ); + } + + // offset1 * s2 * c2^(pow2)) + if *s2 != E::ZERO && *offset1 != E::ZERO { + result = Expression::Sum( + Box::new(result), Box::new(Expression::Challenge( *challenge_id1, - *pow1, - *offset2, + *pow2, + *offset1 * *s2, E::ZERO, )), - )), - // offset1 * s2 * c2^(pow2)) - Box::new(Expression::Challenge( - *challenge_id1, - *pow2, - *offset1, - E::ZERO, - )), - ) + ); + } + + result } else { Expression::Product(Box::new(self), Box::new(rhs)) } @@ -545,10 +566,10 @@ mod tests { E::ONE * E::ONE, )), // offset2 * s1 * c1^(pow1) - Box::new(Expression::Challenge(0, 3, E::ONE, E::ZERO,)), + Box::new(Expression::Challenge(0, 3, 2.into(), E::ZERO)), )), // offset1 * s2 * c2^(pow2)) - Box::new(Expression::Challenge(0, 2, E::ONE, E::ZERO,)), + Box::new(Expression::Challenge(0, 2, 2.into(), E::ZERO)), ) ); } diff --git 
a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs new file mode 100644 index 000000000..c030e0620 --- /dev/null +++ b/ceno_zkvm/src/expression/monomial.rs @@ -0,0 +1,235 @@ +use ff_ext::ExtensionField; +use goldilocks::SmallField; +use std::cmp::Ordering; + +use super::Expression; +use Expression::*; + +impl Expression { + pub(super) fn to_monomial_form_inner(&self) -> Self { + Self::sum_terms(Self::combine(self.distribute())) + } + + fn distribute(&self) -> Vec> { + match self { + Constant(_) => { + vec![Term { + coeff: self.clone(), + vars: vec![], + }] + } + + Fixed(_) | WitIn(_) | Challenge(..) => { + vec![Term { + coeff: Expression::ONE, + vars: vec![self.clone()], + }] + } + + Sum(a, b) => { + let mut res = a.distribute(); + res.extend(b.distribute()); + res + } + + Product(a, b) => { + let a = a.distribute(); + let b = b.distribute(); + let mut res = vec![]; + for a in a { + for b in &b { + res.push(Term { + coeff: a.coeff.clone() * b.coeff.clone(), + vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(), + }); + } + } + res + } + + ScaledSum(x, a, b) => { + let x = x.distribute(); + let a = a.distribute(); + let mut res = b.distribute(); + for x in x { + for a in &a { + res.push(Term { + coeff: x.coeff.clone() * a.coeff.clone(), + vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(), + }); + } + } + res + } + } + } + + fn combine(terms: Vec>) -> Vec> { + let mut res: Vec> = vec![]; + for mut term in terms { + term.vars.sort(); + + if let Some(res_term) = res.iter_mut().find(|res_term| res_term.vars == term.vars) { + res_term.coeff = res_term.coeff.clone() + term.coeff.clone(); + } else { + res.push(term); + } + } + res + } + + fn sum_terms(terms: Vec>) -> Self { + terms + .into_iter() + .map(|term| term.vars.into_iter().fold(term.coeff, |a, b| a * b)) + .reduce(|a, b| a + b) + .unwrap_or(Expression::ZERO) + } +} + +#[derive(Clone, Debug)] +struct Term { + coeff: Expression, + vars: Vec>, +} + +// Define a 
lexicographic order for expressions. It compares the types first, then the arguments left-to-right. +impl Ord for Expression { + fn cmp(&self, other: &Self) -> Ordering { + use Ordering::*; + + match (self, other) { + (Fixed(a), Fixed(b)) => a.cmp(b), + (WitIn(a), WitIn(b)) => a.cmp(b), + (Constant(a), Constant(b)) => cmp_field(a, b), + (Challenge(a, b, c, d), Challenge(e, f, g, h)) => { + let cmp = a.cmp(e); + if cmp == Equal { + let cmp = b.cmp(f); + if cmp == Equal { + let cmp = cmp_ext(c, g); + if cmp == Equal { cmp_ext(d, h) } else { cmp } + } else { + cmp + } + } else { + cmp + } + } + (Sum(a, b), Sum(c, d)) => { + let cmp = a.cmp(c); + if cmp == Equal { b.cmp(d) } else { cmp } + } + (Product(a, b), Product(c, d)) => { + let cmp = a.cmp(c); + if cmp == Equal { b.cmp(d) } else { cmp } + } + (ScaledSum(x, a, b), ScaledSum(y, c, d)) => { + let cmp = x.cmp(y); + if cmp == Equal { + let cmp = a.cmp(c); + if cmp == Equal { b.cmp(d) } else { cmp } + } else { + cmp + } + } + (Fixed(_), _) => Less, + (WitIn(_), _) => Less, + (Constant(_), _) => Less, + (Challenge(..), _) => Less, + (Sum(..), _) => Less, + (Product(..), _) => Less, + (ScaledSum(..), _) => Less, + } + } +} + +impl PartialOrd for Expression { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn cmp_field(a: &F, b: &F) -> Ordering { + a.to_canonical_u64().cmp(&b.to_canonical_u64()) +} + +fn cmp_ext(a: &E, b: &E) -> Ordering { + let a = a.as_bases().iter().map(|f| f.to_canonical_u64()); + let b = b.as_bases().iter().map(|f| f.to_canonical_u64()); + a.cmp(b) +} + +#[cfg(test)] +mod tests { + use crate::{expression::Fixed as FixedS, scheme::utils::eval_by_expr_with_fixed}; + + use super::*; + use ff::Field; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use rand_chacha::{rand_core::SeedableRng, ChaChaRng}; + + #[test] + fn test_to_monomial_form() { + use Expression::*; + + let eval = make_eval(); + + let a = || Fixed(FixedS(0)); + let b = || Fixed(FixedS(1)); + 
let c = || Fixed(FixedS(2)); + let x = || WitIn(0); + let y = || WitIn(1); + let z = || WitIn(2); + let n = || Constant(104.into()); + let m = || Constant(-F::from(599)); + let r = || Challenge(0, 1, E::from(1), E::from(0)); + + let test_exprs: &[Expression] = &[ + a() * x() * x(), + a(), + x(), + n(), + r(), + a() + b() + x() + y() + n() + m() + r(), + a() * x() * n() * r(), + x() * y() * z(), + (x() + y() + a()) * b() * (y() + z()) + c(), + (r() * x() + n() + z()) * m() * y(), + (b() + y() + m() * z()) * (x() + y() + c()), + a() * r() * x(), + ]; + + for factored in test_exprs { + let monomials = factored.to_monomial_form_inner(); + assert!(monomials.is_monomial_form()); + + // Check that the two forms are equivalent (Schwartz-Zippel test). + let factored = eval(&factored); + let monomials = eval(&monomials); + assert_eq!(monomials, factored); + } + } + + /// Create an evaluator of expressions. Fixed, witness, and challenge values are pseudo-random. + fn make_eval() -> impl Fn(&Expression) -> E { + // Create a deterministic RNG from a seed. 
+ let mut rng = ChaChaRng::from_seed([12u8; 32]); + let fixed = vec![ + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + ]; + let witnesses = vec![ + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + ]; + let challenges = vec![ + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + ]; + move |expr: &Expression| eval_by_expr_with_fixed(&fixed, &witnesses, &challenges, expr) + } +} diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs new file mode 100644 index 000000000..b4edaa162 --- /dev/null +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -0,0 +1,80 @@ +use std::mem::MaybeUninit; + +use ff_ext::ExtensionField; +use goldilocks::SmallField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + set_val, +}; + +pub struct IsZeroConfig { + is_zero: WitIn, + inverse: WitIn, +} + +impl IsZeroConfig { + pub fn expr(&self) -> Expression { + self.is_zero.expr() + } + + pub fn construct_circuit( + cb: &mut CircuitBuilder, + x: Expression, + ) -> Result { + let is_zero = cb.create_witin(|| "is_zero")?; + let inverse = cb.create_witin(|| "inv")?; + + // x==0 => is_zero=1 + cb.require_one(|| "is_zero_1", is_zero.expr() + x.clone() * inverse.expr())?; + + // x!=0 => is_zero=0 + cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?; + + Ok(IsZeroConfig { is_zero, inverse }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + x: F, + ) -> Result<(), ZKVMError> { + let (is_zero, inverse) = if x.is_zero_vartime() { + (F::ONE, F::ZERO) + } else { + (F::ZERO, x.invert().expect("not zero")) + }; + + set_val!(instance, self.is_zero, is_zero); + set_val!(instance, self.inverse, inverse); + + Ok(()) + } +} + +pub struct IsEqualConfig(IsZeroConfig); + +impl IsEqualConfig { + pub fn expr(&self) -> Expression { + self.0.expr() + } + + pub fn construct_circuit( + cb: &mut CircuitBuilder, + a: Expression, + b: Expression, + ) 
-> Result { + Ok(IsEqualConfig(IsZeroConfig::construct_circuit(cb, a - b)?)) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + a: F, + b: F, + ) -> Result<(), ZKVMError> { + self.0.assign_instance(instance, a - b) + } +} diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs new file mode 100644 index 000000000..6851cb9a6 --- /dev/null +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -0,0 +1,2 @@ +mod is_zero; +pub use is_zero::{IsEqualConfig, IsZeroConfig}; diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 1e3630988..f3901146d 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -1,13 +1,16 @@ use ceno_emul::InsnKind; pub mod arith; +mod b_insn; pub mod blt; +pub mod branch; pub mod config; pub mod constants; +pub mod divu; +mod i_insn; pub mod logic; - -mod b_insn; mod r_insn; +pub mod shift_imm; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 90e2bbcee..8194d0e80 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -38,6 +38,7 @@ use core::mem::MaybeUninit; #[derive(Debug)] pub struct BInstructionConfig { pc: WitIn, + next_pc: WitIn, ts: WitIn, rs1_id: WitIn, rs2_id: WitIn, @@ -70,11 +71,11 @@ impl BInstructionConfig { circuit_builder.lk_fetch(&InsnRecord::new( pc.expr(), (insn_kind.codes().opcode as usize).into(), - 0.into(), // TODO: Make sure the program table sets rd=0. + 0.into(), (insn_kind.codes().func3 as usize).into(), rs1_id.expr(), rs2_id.expr(), - imm.expr(), // TODO: Make sure the program table sets the full immediate. + imm.expr(), ))?; // Register state. @@ -98,13 +99,20 @@ impl BInstructionConfig { )?; // State out. 
- let pc_offset = branch_taken_bit * (imm.expr() - PC_STEP_SIZE.into()) + PC_STEP_SIZE.into(); - let next_pc = pc.expr() + pc_offset; + let next_pc = { + let pc_offset = branch_taken_bit.clone() * imm.expr() + - branch_taken_bit * PC_STEP_SIZE.into() + + PC_STEP_SIZE.into(); + let next_pc = circuit_builder.create_witin(|| "next_pc")?; + circuit_builder.require_equal(|| "pc_branch", next_pc.expr(), pc.expr() + pc_offset)?; + next_pc + }; let next_ts = cur_ts.expr() + 4.into(); - circuit_builder.state_out(next_pc, next_ts)?; + circuit_builder.state_out(next_pc.expr(), next_ts)?; Ok(BInstructionConfig { pc, + next_pc, ts: cur_ts, rs1_id, rs2_id, @@ -122,14 +130,19 @@ impl BInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - // State in. + // State. set_val!(instance, self.pc, step.pc().before.0 as u64); + set_val!(instance, self.next_pc, step.pc().after.0 as u64); set_val!(instance, self.ts, step.cycle()); // Register indexes and immediate. set_val!(instance, self.rs1_id, step.insn().rs1() as u64); set_val!(instance, self.rs2_id, step.insn().rs2() as u64); - set_val!(instance, self.imm, step.insn().imm_b() as u64); + set_val!( + instance, + self.imm, + InsnRecord::imm_or_funct7_field::(&step.insn()) + ); // Fetch the instruction. 
lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/blt.rs b/ceno_zkvm/src/instructions/riscv/blt.rs index d21b50ef5..dd9eb475b 100644 --- a/ceno_zkvm/src/instructions/riscv/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/blt.rs @@ -1,215 +1,34 @@ use ceno_emul::InsnKind; -use goldilocks::SmallField; -use std::mem::MaybeUninit; use ff_ext::ExtensionField; use crate::{ - chip_handler::{GlobalStateRegisterMachineChipOperations, RegisterChipOperations}, circuit_builder::CircuitBuilder, - create_witin_from_expr, error::ZKVMError, - expression::{ToExpr, WitIn}, + expression::ToExpr, instructions::{ - riscv::config::{ExprLtInput, UIntLtConfig, UIntLtInput}, + riscv::config::{UIntLtConfig, UIntLtInput}, Instruction, }, - set_val, - utils::i64_to_base, + utils::{i64_to_base, split_to_u8}, witness::LkMultiplicity, }; -use super::{ - config::ExprLtConfig, - constants::{UInt, UInt8, PC_STEP_SIZE}, - RIVInstruction, -}; +use super::{b_insn::BInstructionConfig, constants::UInt8, RIVInstruction}; pub struct BltInstruction; pub struct InstructionConfig { - pub pc: WitIn, - pub next_pc: WitIn, - pub ts: WitIn, - pub imm: WitIn, - pub lhs_limb8: UInt8, - pub rhs_limb8: UInt8, - pub rs1_id: WitIn, - pub rs2_id: WitIn, - pub prev_rs1_ts: WitIn, - pub prev_rs2_ts: WitIn, + pub b_insn: BInstructionConfig, + pub read_rs1: UInt8, + pub read_rs2: UInt8, pub is_lt: UIntLtConfig, - pub lt_rs1_cfg: ExprLtConfig, - pub lt_rs2_cfg: ExprLtConfig, -} - -pub struct BltInput { - pub pc: u16, - pub ts: u16, - pub imm: i16, // rust don't have i12 - pub lhs_limb8: Vec, - pub rhs_limb8: Vec, - pub rs1_id: u8, - pub rs2_id: u8, - pub prev_rs1_ts: u16, - pub prev_rs2_ts: u16, -} - -impl BltInput { - /// TODO: refactor after formalize the interface of opcode inputs - pub fn assign>( - &self, - config: &InstructionConfig, - instance: &mut [MaybeUninit], - lk_multiplicity: &mut LkMultiplicity, - ) { - assert!(!self.lhs_limb8.is_empty() && (self.lhs_limb8.len() == 
self.rhs_limb8.len())); - // TODO: add boundary check for witin - let lt_input = UIntLtInput { - lhs_limbs: &self.lhs_limb8, - rhs_limbs: &self.rhs_limb8, - }; - let is_lt = lt_input.assign(instance, &config.is_lt); - - set_val!(instance, config.pc, { i64_to_base::(self.pc as i64) }); - set_val!(instance, config.next_pc, { - if is_lt { - i64_to_base::(self.pc as i64 + self.imm as i64) - } else { - i64_to_base::(self.pc as i64 + PC_STEP_SIZE as i64) - } - }); - set_val!(instance, config.ts, { i64_to_base::(self.ts as i64) }); - set_val!(instance, config.imm, { i64_to_base::(self.imm as i64) }); - set_val!(instance, config.rs1_id, { - i64_to_base::(self.rs1_id as i64) - }); - set_val!(instance, config.rs2_id, { - i64_to_base::(self.rs2_id as i64) - }); - set_val!(instance, config.prev_rs1_ts, { - i64_to_base::(self.prev_rs1_ts as i64) - }); - set_val!(instance, config.prev_rs2_ts, { - i64_to_base::(self.prev_rs2_ts as i64) - }); - - config.lhs_limb8.assign_limbs(instance, { - self.lhs_limb8 - .iter() - .map(|&limb| i64_to_base::(limb as i64)) - .collect() - }); - config.rhs_limb8.assign_limbs(instance, { - self.rhs_limb8 - .iter() - .map(|&limb| i64_to_base::(limb as i64)) - .collect() - }); - ExprLtInput { - lhs: self.prev_rs1_ts as u64, - rhs: self.ts as u64, - } - .assign(instance, &config.lt_rs1_cfg, lk_multiplicity); - ExprLtInput { - lhs: self.prev_rs2_ts as u64, - rhs: (self.ts + 1) as u64, - } - .assign(instance, &config.lt_rs2_cfg, lk_multiplicity); - } - - pub fn random() -> Self { - use ark_std::{rand::Rng, test_rng}; - let mut rng = test_rng(); - - // hack to generate valid inputs - let ts_bound: u16 = rng.gen_range(100..1000); - let pc_bound: u16 = rng.gen_range(100..1000); - - Self { - pc: rng.gen_range(pc_bound..(1 << 15)), - ts: rng.gen_range(ts_bound..(1 << 15)), - imm: rng.gen_range(-(pc_bound as i16)..2047), - // this is for riscv32 inputs - lhs_limb8: (0..4).map(|_| rng.gen()).collect(), - rhs_limb8: (0..4).map(|_| rng.gen()).collect(), - rs1_id: 
rng.gen(), - rs2_id: rng.gen(), - prev_rs1_ts: rng.gen_range(0..ts_bound), - prev_rs2_ts: rng.gen_range(0..ts_bound), - } - } } impl RIVInstruction for BltInstruction { const INST_KIND: InsnKind = InsnKind::BLT; } -/// if (rs1 < rs2) PC += sext(imm) -fn blt_gadget( - circuit_builder: &mut CircuitBuilder, -) -> Result, ZKVMError> { - let pc = circuit_builder.create_witin(|| "pc")?; - // imm is already sext(imm) from instruction - let imm = circuit_builder.create_witin(|| "imm")?; - let cur_ts = circuit_builder.create_witin(|| "ts")?; - circuit_builder.state_in(pc.expr(), cur_ts.expr())?; - - // TODO: constraint rs1_id, rs2_id by bytecode lookup - let rs1_id = circuit_builder.create_witin(|| "rs1_id")?; - let rs2_id = circuit_builder.create_witin(|| "rs2_id")?; - - let lhs_limb8 = UInt8::new(|| "lhs_limb8", circuit_builder)?; - let rhs_limb8 = UInt8::new(|| "rhs_limb8", circuit_builder)?; - - let is_lt = lhs_limb8.lt_limb8(circuit_builder, &rhs_limb8)?; - - // update pc - let next_pc = pc.expr() + is_lt.is_lt.expr() * imm.expr() + PC_STEP_SIZE.into() - - is_lt.is_lt.expr() * PC_STEP_SIZE.into(); - - // update ts - let prev_rs1_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; - let prev_rs2_ts = circuit_builder.create_witin(|| "prev_rs2_ts")?; - let lhs = UInt::from_u8_limbs(&lhs_limb8)?; - let rhs = UInt::from_u8_limbs(&rhs_limb8)?; - - let (ts, lt_rs1_cfg) = circuit_builder.register_read( - || "read ts for rs1", - &rs1_id, - prev_rs1_ts.expr(), - cur_ts.expr(), - lhs.register_expr(), - )?; - let (ts, lt_rs2_cfg) = circuit_builder.register_read( - || "read ts for rs2", - &rs2_id, - prev_rs2_ts.expr(), - ts, - rhs.register_expr(), - )?; - - let next_pc = create_witin_from_expr!(|| "next_pc", circuit_builder, false, next_pc)?; - let next_ts = ts + 1.into(); - circuit_builder.state_out(next_pc.expr(), next_ts)?; - - Ok(InstructionConfig { - pc, - next_pc, - ts: cur_ts, - lhs_limb8, - rhs_limb8, - imm, - rs1_id, - rs2_id, - prev_rs1_ts, - prev_rs2_ts, - is_lt, - 
lt_rs1_cfg, - lt_rs2_cfg, - }) -} - impl Instruction for BltInstruction { // const NAME: &'static str = "BLT"; fn name() -> String { @@ -219,18 +38,56 @@ impl Instruction for BltInstruction { fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result, ZKVMError> { - blt_gadget::(circuit_builder) + let read_rs1 = UInt8::new_unchecked(|| "rs1_limbs", circuit_builder)?; + let read_rs2 = UInt8::new_unchecked(|| "rs2_limbs", circuit_builder)?; + let is_lt = read_rs1.lt_limb8(circuit_builder, &read_rs2)?; + + let b_insn = BInstructionConfig::construct_circuit( + circuit_builder, + Self::INST_KIND, + read_rs1.register_expr(), + read_rs2.register_expr(), + is_lt.is_lt.expr(), + )?; + + Ok(InstructionConfig { + b_insn, + read_rs1, + read_rs2, + is_lt, + }) } fn assign_instance( config: &Self::InstructionConfig, instance: &mut [std::mem::MaybeUninit], lk_multiplicity: &mut LkMultiplicity, - _step: &ceno_emul::StepRecord, + step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { - // take input from _step - let input = BltInput::random(); - input.assign(config, instance, lk_multiplicity); + let rs1_limbs = split_to_u8(step.rs1().unwrap().value); + let rs2_limbs = split_to_u8(step.rs2().unwrap().value); + config.read_rs1.assign_limbs(instance, { + rs1_limbs + .iter() + .map(|&limb| i64_to_base::(limb as i64)) + .collect() + }); + config.read_rs2.assign_limbs(instance, { + rs2_limbs + .iter() + .map(|&limb| i64_to_base::(limb as i64)) + .collect() + }); + let lt_input = UIntLtInput { + lhs_limbs: &rs1_limbs, + rhs_limbs: &rs2_limbs, + }; + lt_input.assign(instance, &config.is_lt, lk_multiplicity); + + config + .b_insn + .assign_instance::(instance, lk_multiplicity, step)?; + Ok(()) } } @@ -243,7 +100,11 @@ mod test { use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; - use crate::{circuit_builder::ConstraintSystem, scheme::mock_prover::MockProver}; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + 
instructions::Instruction, + scheme::mock_prover::{MockProver, MOCK_PC_BLT, MOCK_PROGRAM}, + }; #[test] fn test_blt_circuit() -> Result<(), ZKVMError> { @@ -253,11 +114,17 @@ mod test { let num_wits = circuit_builder.cs.num_witin as usize; // generate mock witness - let num_instances = 1 << 4; let (raw_witin, _) = BltInstruction::assign_instances( &config, num_wits, - vec![StepRecord::default(); num_instances], + vec![StepRecord::new_b_instruction( + 3, + MOCK_PC_BLT, + MOCK_PROGRAM[8], + 0x20, + 0x21, + 0, + )], ) .unwrap(); @@ -273,11 +140,4 @@ mod test { ); Ok(()) } - - fn bench_blt_instruction_helper(_instance_num_vars: usize) {} - - #[test] - fn bench_blt_instruction() { - bench_blt_instruction_helper::(10); - } } diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs new file mode 100644 index 000000000..4a5f159b1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -0,0 +1,19 @@ +mod beq_circuit; +use super::RIVInstruction; +use beq_circuit::BeqCircuit; +use ceno_emul::InsnKind; + +#[cfg(test)] +mod test; + +pub struct BeqOp; +impl RIVInstruction for BeqOp { + const INST_KIND: InsnKind = InsnKind::BEQ; +} +pub type BeqInstruction = BeqCircuit; + +pub struct BneOp; +impl RIVInstruction for BneOp { + const INST_KIND: InsnKind = InsnKind::BNE; +} +pub type BneInstruction = BeqCircuit; diff --git a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs new file mode 100644 index 000000000..2e4371d26 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs @@ -0,0 +1,97 @@ +use std::{marker::PhantomData, mem::MaybeUninit}; + +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::Expression, + gadgets::IsEqualConfig, + instructions::{ + riscv::{b_insn::BInstructionConfig, constants::UInt, RIVInstruction}, + Instruction, + }, + 
witness::LkMultiplicity, + Value, +}; + +pub struct BeqConfig { + b_insn: BInstructionConfig, + + // TODO: Limb decomposition is not necessary. Replace with a single witness. + rs1_read: UInt, + rs2_read: UInt, + + equal: IsEqualConfig, +} + +pub struct BeqCircuit(PhantomData<(E, I)>); + +impl Instruction for BeqCircuit { + type InstructionConfig = BeqConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + + let equal = + IsEqualConfig::construct_circuit(circuit_builder, rs2_read.value(), rs1_read.value())?; + + let branch_taken_bit = match I::INST_KIND { + InsnKind::BEQ => equal.expr(), + InsnKind::BNE => Expression::ONE - equal.expr(), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let b_insn = BInstructionConfig::construct_circuit( + circuit_builder, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + branch_taken_bit, + )?; + + Ok(BeqConfig { + b_insn, + rs1_read, + rs2_read, + equal, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .b_insn + .assign_instance::(instance, lk_multiplicity, step)?; + + let rs1_read = step.rs1().unwrap().value; + config + .rs1_read + .assign_limbs(instance, Value::new_unchecked(rs1_read).u16_fields()); + + let rs2_read = step.rs2().unwrap().value; + config + .rs2_read + .assign_limbs(instance, Value::new_unchecked(rs2_read).u16_fields()); + + config.equal.assign_instance( + instance, + E::BaseField::from(rs2_read as u64), + E::BaseField::from(rs1_read as u64), + )?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs 
b/ceno_zkvm/src/instructions/riscv/branch/test.rs new file mode 100644 index 000000000..83fe36d04 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -0,0 +1,108 @@ +use ceno_emul::{Change, StepRecord, Word, PC_STEP_SIZE}; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; +use multilinear_extensions::mle::IntoMLEs; + +use super::*; +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + scheme::mock_prover::{MockProver, MOCK_PC_BEQ, MOCK_PC_BNE, MOCK_PROGRAM}, +}; + +const A: Word = 0xbead1010; +const B: Word = 0xef552020; + +#[test] +fn test_opcode_beq() { + impl_opcode_beq(false); + impl_opcode_beq(true); +} + +fn impl_opcode_beq(equal: bool) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "beq", + |cb| { + let config = BeqInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; + let (raw_witin, _lkm) = BeqInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_b_instruction( + 3, + Change::new(MOCK_PC_BEQ, MOCK_PC_BEQ + pc_offset), + MOCK_PROGRAM[6], + A, + if equal { A } else { B }, + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} + +#[test] +fn test_opcode_bne() { + impl_opcode_bne(false); + impl_opcode_bne(true); +} + +fn impl_opcode_bne(equal: bool) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "bne", + |cb| { + let config = BneInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; + let (raw_witin, _lkm) = BneInstruction::assign_instances( + &config, + 
cb.cs.num_witin as usize, + vec![StepRecord::new_b_instruction( + 3, + Change::new(MOCK_PC_BNE, MOCK_PC_BNE + pc_offset), + MOCK_PROGRAM[7], + A, + if equal { A } else { B }, + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} diff --git a/ceno_zkvm/src/instructions/riscv/config.rs b/ceno_zkvm/src/instructions/riscv/config.rs index 2396475cd..5c637f5f4 100644 --- a/ceno_zkvm/src/instructions/riscv/config.rs +++ b/ceno_zkvm/src/instructions/riscv/config.rs @@ -27,6 +27,7 @@ impl MsbInput<'_> { &self, instance: &mut [MaybeUninit], config: &MsbConfig, + lk_multiplicity: &mut LkMultiplicity, ) -> (u8, u8) { let n_limbs = self.limbs.len(); assert!(n_limbs > 0); @@ -37,6 +38,7 @@ impl MsbInput<'_> { set_val!(instance, config.high_limb_no_msb, { i64_to_base::(high_limb as i64) }); + lk_multiplicity.lookup_and_byte(high_limb as u64, 0b0111_1111); (msb, high_limb) } } @@ -61,6 +63,7 @@ impl UIntLtuInput<'_> { &self, instance: &mut [MaybeUninit], config: &UIntLtuConfig, + lk_multiplicity: &mut LkMultiplicity, ) -> bool { let mut idx = 0; let mut flag: bool = false; @@ -83,6 +86,15 @@ impl UIntLtuInput<'_> { set_val!(instance, config.indexes[idx], { i64_to_base::(flag as i64) }); + // (0..config.indexes.len()).for_each(|i| { + // if i == idx { + // lk_multiplicity.assert_ux::<1>(0); + // } else { + // lk_multiplicity.assert_ux::<1>(flag as u64); + // } + // }); + // this corresponds to assert_bit of index_sum + // lk_multiplicity.assert_ux::<1>(flag as u64); config.acc_indexes.iter().enumerate().for_each(|(id, wit)| { if id <= idx { set_val!(instance, wit, { i64_to_base::(flag as i64) }); @@ -102,6 +114,7 @@ impl UIntLtuInput<'_> { } }); let is_ltu = self.lhs_limbs[idx] < self.rhs_limbs[idx]; + lk_multiplicity.lookup_ltu_byte(self.lhs_limbs[idx] as u64, self.rhs_limbs[idx] as u64); set_val!(instance, config.is_ltu, { i64_to_base::(is_ltu 
as i64) }); is_ltu } @@ -127,16 +140,19 @@ impl UIntLtInput<'_> { &self, instance: &mut [MaybeUninit], config: &UIntLtConfig, + lk_multiplicity: &mut LkMultiplicity, ) -> bool { let n_limbs = self.lhs_limbs.len(); let lhs_msb_input = MsbInput { limbs: self.lhs_limbs, }; - let (lhs_msb, lhs_high_limb_no_msb) = lhs_msb_input.assign(instance, &config.lhs_msb); + let (lhs_msb, lhs_high_limb_no_msb) = + lhs_msb_input.assign(instance, &config.lhs_msb, lk_multiplicity); let rhs_msb_input = MsbInput { limbs: self.rhs_limbs, }; - let (rhs_msb, rhs_high_limb_no_msb) = rhs_msb_input.assign(instance, &config.rhs_msb); + let (rhs_msb, rhs_high_limb_no_msb) = + rhs_msb_input.assign(instance, &config.rhs_msb, lk_multiplicity); let mut lhs_limbs_no_msb = self.lhs_limbs.iter().copied().collect_vec(); lhs_limbs_no_msb[n_limbs - 1] = lhs_high_limb_no_msb; @@ -148,7 +164,7 @@ impl UIntLtInput<'_> { lhs_limbs: &lhs_limbs_no_msb, rhs_limbs: &rhs_limbs_no_msb, }; - let is_ltu = ltu_input.assign::(instance, &config.is_ltu); + let is_ltu = ltu_input.assign::(instance, &config.is_ltu, lk_multiplicity); let msb_is_equal = lhs_msb == rhs_msb; let msb_diff_inv = if msb_is_equal { @@ -166,12 +182,14 @@ impl UIntLtInput<'_> { // is_lt = a_s\cdot (1-b_s)+eq(a_s,b_s)\cdot ltu(a_{(is_lt as i64) }); + // lk_multiplicity.assert_ux::<1>(is_lt as u64); assert!(is_lt == 0 || is_lt == 1); is_lt > 0 } } +// TODO move ExprLtConfig to gadgets #[derive(Debug)] pub struct ExprLtConfig { pub is_lt: Option, diff --git a/ceno_zkvm/src/instructions/riscv/divu.rs b/ceno_zkvm/src/instructions/riscv/divu.rs new file mode 100644 index 000000000..ce081eb07 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/divu.rs @@ -0,0 +1,220 @@ +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; +use itertools::Itertools; + +use super::{constants::UInt, r_insn::RInstructionConfig, RIVInstruction}; +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::IsZeroConfig, + 
instructions::Instruction, uint::Value, witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; +use std::marker::PhantomData; + +pub struct ArithConfig { + r_insn: RInstructionConfig, + + dividend: UInt, + divisor: UInt, + outcome: UInt, + + remainder: UInt, + inter_mul_value: UInt, + is_zero: IsZeroConfig, +} + +pub struct ArithInstruction(PhantomData<(E, I)>); + +pub struct DivUOp; +impl RIVInstruction for DivUOp { + const INST_KIND: InsnKind = InsnKind::DIVU; +} +pub type DivUInstruction = ArithInstruction; + +impl Instruction for ArithInstruction { + type InstructionConfig = ArithConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + // outcome = dividend / divisor + remainder => dividend = divisor * outcome + r + let mut divisor = UInt::new_unchecked(|| "divisor", circuit_builder)?; + let mut outcome = UInt::new(|| "outcome", circuit_builder)?; + let r = UInt::new(|| "remainder", circuit_builder)?; + + let (inter_mul_value, dividend) = + divisor.mul_add(|| "dividend", circuit_builder, &mut outcome, &r, true)?; + + // div by zero check + let is_zero = IsZeroConfig::construct_circuit(circuit_builder, divisor.value())?; + + let outcome_value = outcome.value(); + circuit_builder + .condition_require_equal( + || "outcome_is_zero", + is_zero.expr(), + outcome_value.clone(), + ((1 << UInt::::M) - 1).into(), + outcome_value, + ) + .unwrap(); + + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + dividend.register_expr(), + divisor.register_expr(), + outcome.register_expr(), + )?; + + Ok(ArithConfig { + r_insn, + dividend, + divisor, + outcome, + remainder: r, + inter_mul_value, + is_zero, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = step.rs1().unwrap().value; + let rs2 = 
step.rs2().unwrap().value; + let rd = step.rd().unwrap().value.after; + + // dividend = divisor * outcome + r + let dividend = Value::new_unchecked(rs1); + let divisor = Value::new_unchecked(rs2); + let outcome = Value::new(rd, lkm); + + // divisor * outcome + let inter_mul_value = Value::new(rs2 * rd, lkm); + let r = if rs2 == 0 { + Value::new_unchecked(0) + } else { + Value::new(rs1 % rs2, lkm) + }; + + // assignment + config.r_insn.assign_instance(instance, lkm, step)?; + config.divisor.assign_limbs(instance, divisor.u16_fields()); + config.outcome.assign_limbs(instance, outcome.u16_fields()); + + let (_, mul_carries, add_carries) = divisor.mul_add(&outcome, &r, lkm, true); + config + .inter_mul_value + .assign_limbs(instance, inter_mul_value.u16_fields()); + config.inter_mul_value.assign_carries( + instance, + mul_carries + .into_iter() + .map(|carry| E::BaseField::from(carry as u64)) + .collect_vec(), + ); + config.remainder.assign_limbs(instance, r.u16_fields()); + + config + .dividend + .assign_limbs(instance, dividend.u16_fields()); + config.dividend.assign_carries( + instance, + add_carries + .into_iter() + .map(|carry| E::BaseField::from(carry as u64)) + .collect_vec(), + ); + + config + .is_zero + .assign_instance(instance, (rs2 as u64).into())?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + + mod divu { + use std::u32; + + use ceno_emul::{Change, StepRecord, Word}; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + use rand::Rng; + + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{riscv::divu::DivUInstruction, Instruction}, + scheme::mock_prover::{MockProver, MOCK_PC_DIVU, MOCK_PROGRAM}, + }; + + fn verify(name: &'static str, dividend: Word, divisor: Word, outcome: Word) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || format!("divu_{name}"), + |cb| 
Ok(DivUInstruction::construct_circuit(cb)), + ) + .unwrap() + .unwrap(); + + // values assignment + let (raw_witin, _) = DivUInstruction::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_r_instruction( + 3, + MOCK_PC_DIVU, + MOCK_PROGRAM[9], + dividend, + divisor, + Change::new(0, outcome), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &mut cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); + } + #[test] + fn test_opcode_divu() { + verify("basic", 10, 2, 5); + verify("dividend > divisor", 10, 11, 0); + verify("remainder", 11, 2, 5); + verify("u32::MAX", u32::MAX, u32::MAX, 1); + verify("u32::MAX div by 2", u32::MAX, 2, u32::MAX / 2); + verify("div by zero", 10, 0, u32::MAX); + verify("mul carry", 1202729773, 171818539, 7); + } + + #[test] + fn test_opcode_divu_random() { + let mut rng = rand::thread_rng(); + let a: u32 = rng.gen(); + let b: u32 = rng.gen_range(1..u32::MAX); + println!("random: {} / {} = {}", a, b, a / b); + verify("random", a, b, a / b); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs new file mode 100644 index 000000000..203d32720 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -0,0 +1,151 @@ +use ceno_emul::{InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use super::{ + config::ExprLtConfig, + constants::{UInt, PC_STEP_SIZE}, +}; +use crate::{ + chip_handler::{ + GlobalStateRegisterMachineChipOperations, RegisterChipOperations, RegisterExpr, + }, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::config::ExprLtInput, + set_val, + tables::InsnRecord, + uint::Value, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +/// This config handles the common part of I-type instructions: +/// - PC, cycle, fetch. +/// - Registers read and write. 
+/// +/// It does not witness of the register values, nor the actual function (e.g. srli, addi, etc). +#[derive(Debug)] +pub struct IInstructionConfig { + pub pc: WitIn, + pub ts: WitIn, + pub rs1_id: WitIn, + pub rd_id: WitIn, + pub prev_rd_value: UInt, + pub prev_rs1_ts: WitIn, + pub prev_rd_ts: WitIn, + pub lt_rs1_cfg: ExprLtConfig, + pub lt_rd_cfg: ExprLtConfig, +} + +impl IInstructionConfig { + pub fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + insn_kind: InsnKind, + imm: &Expression, + rs1_read: RegisterExpr, + rd_written: RegisterExpr, + ) -> Result { + // State in. + let pc = circuit_builder.create_witin(|| "pc")?; + let cur_ts = circuit_builder.create_witin(|| "cur_ts")?; + circuit_builder.state_in(pc.expr(), cur_ts.expr())?; + + // Register indexes. + let rs1_id = circuit_builder.create_witin(|| "rs1_id")?; + let rd_id = circuit_builder.create_witin(|| "rd_id")?; + + // Fetch the instruction. + circuit_builder.lk_fetch(&InsnRecord::new( + pc.expr(), + (insn_kind.codes().opcode as usize).into(), + rd_id.expr(), + (insn_kind.codes().func3 as usize).into(), + rs1_id.expr(), + 0.into(), + imm.clone(), + ))?; + + // Register state. + let prev_rs1_ts = circuit_builder.create_witin(|| "prev_rs1_ts")?; + let prev_rd_ts = circuit_builder.create_witin(|| "prev_rd_ts")?; + let prev_rd_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; + + // Register read and write. + let (next_ts, lt_rs1_cfg) = circuit_builder.register_read( + || "read_rs1", + &rs1_id, + prev_rs1_ts.expr(), + cur_ts.expr(), + rs1_read, + )?; + let (next_ts, lt_rd_cfg) = circuit_builder.register_write( + || "write_rd", + &rd_id, + prev_rd_ts.expr(), + next_ts, + prev_rd_value.register_expr(), + rd_written, + )?; + + // State out. 
+ let next_pc = pc.expr() + PC_STEP_SIZE.into(); + circuit_builder.state_out(next_pc, next_ts)?; + + Ok(IInstructionConfig { + pc, + ts: cur_ts, + rs1_id, + rd_id, + prev_rd_value, + prev_rs1_ts, + prev_rd_ts, + lt_rs1_cfg, + lt_rd_cfg, + }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + // State in. + set_val!(instance, self.pc, step.pc().before.0 as u64); + set_val!(instance, self.ts, step.cycle()); + + // Register indexes. + set_val!(instance, self.rs1_id, step.insn().rs1() as u64); + set_val!(instance, self.rd_id, step.insn().rd() as u64); + + // Fetch the instruction. + lk_multiplicity.fetch(step.pc().before.0); + + // Register state. + set_val!( + instance, + self.prev_rs1_ts, + step.rs1().unwrap().previous_cycle + ); + set_val!(instance, self.prev_rd_ts, step.rd().unwrap().previous_cycle); + self.prev_rd_value.assign_limbs( + instance, + Value::new_unchecked(step.rd().unwrap().value.before).u16_fields(), + ); + + // Register read and write. + ExprLtInput { + lhs: step.rs1().unwrap().previous_cycle, + rhs: step.cycle(), + } + .assign(instance, &self.lt_rs1_cfg, lk_multiplicity); + ExprLtInput { + lhs: step.rd().unwrap().previous_cycle, + rhs: step.cycle() + 1, + } + .assign(instance, &self.lt_rd_cfg, lk_multiplicity); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index 7782060a2..b3297650b 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -77,32 +77,31 @@ impl RInstructionConfig { let prev_rd_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; // Register read and write. 
- let (ts, lt_rs1_cfg) = circuit_builder.register_read( + let (next_ts, lt_rs1_cfg) = circuit_builder.register_read( || "read_rs1", &rs1_id, prev_rs1_ts.expr(), cur_ts.expr(), rs1_read, )?; - let (ts, lt_rs2_cfg) = circuit_builder.register_read( + let (next_ts, lt_rs2_cfg) = circuit_builder.register_read( || "read_rs2", &rs2_id, prev_rs2_ts.expr(), - ts, + next_ts, rs2_read, )?; - let (ts, lt_prev_ts_cfg) = circuit_builder.register_write( + let (next_ts, lt_prev_ts_cfg) = circuit_builder.register_write( || "write_rd", &rd_id, prev_rd_ts.expr(), - ts, + next_ts, prev_rd_value.register_expr(), rd_written, )?; // State out. let next_pc = pc.expr() + PC_STEP_SIZE.into(); - let next_ts = ts + 1.into(); circuit_builder.state_out(next_pc, next_ts)?; Ok(RInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs new file mode 100644 index 000000000..a7da58377 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -0,0 +1,12 @@ +use super::RIVInstruction; + +mod shift_imm_circuit; + +#[cfg(test)] +mod test; + +pub struct SrliOp; + +impl RIVInstruction for SrliOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRLI; +} diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs new file mode 100644 index 000000000..28db629d2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -0,0 +1,116 @@ +use std::{marker::PhantomData, mem::MaybeUninit}; + +use ceno_emul::StepRecord; +use ff_ext::ExtensionField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + riscv::{constants::UInt, i_insn::IInstructionConfig, RIVInstruction}, + Instruction, + }, + witness::LkMultiplicity, + Value, +}; + +pub struct ShiftImmInstruction(PhantomData<(E, I)>); + +pub struct InstructionConfig { + i_insn: IInstructionConfig, + + imm: UInt, + rd_written: 
UInt, + remainder: UInt, + rd_imm_mul: UInt, + rd_imm_rem_add: UInt, +} + +impl Instruction for ShiftImmInstruction { + type InstructionConfig = InstructionConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + let mut imm = UInt::new(|| "imm", circuit_builder)?; + let mut rd_written = UInt::new_unchecked(|| "rd_written", circuit_builder)?; + + // Note: `imm` is set to 2**imm (upto 32 bit) just for SRLI for efficient verification + // Goal is to constrain: + // rs1_read == rd_written * imm + remainder + let remainder = UInt::new(|| "remainder", circuit_builder)?; + let (rd_imm_mul, rd_imm_rem_add) = rd_written.mul_add( + || "rd_written * imm +remainder ", + circuit_builder, + &mut imm, + &remainder, + true, + )?; + + let i_insn = IInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &imm.value(), + rd_imm_rem_add.register_expr(), + rd_written.register_expr(), + )?; + + Ok(InstructionConfig { + i_insn, + imm, + rd_written, + remainder, + rd_imm_mul, + rd_imm_rem_add, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit<::BaseField>], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + // We need to calculate result and remainder. + let rs1_read = step.rs1().unwrap().value; + let rd_written = step.rd().unwrap().value.after; + let imm = step.insn().imm_or_funct7(); + let result = rs1_read.wrapping_div(imm); + let remainder = rs1_read.wrapping_sub(result * imm); + assert_eq!(result, rd_written, "SRLI: result mismatch"); + + // Assignment. 
+ let rd_written = Value::new_unchecked(rd_written); + let imm = Value::new(imm, lk_multiplicity); + let remainder = Value::new(remainder, lk_multiplicity); + + let rd_imm_mul = rd_written.mul(&imm, lk_multiplicity, true); + let rd_imm = Value::from_limb_slice_unchecked(&rd_imm_mul.0); + config + .rd_imm_mul + .assign_limb_with_carry(instance, &rd_imm_mul); + + let rd_imm_rem_add = rd_imm.add(&remainder, lk_multiplicity, true); + debug_assert_eq!( + Value::::from_limb_slice_unchecked(&rd_imm_rem_add.0).as_u64(), + rs1_read as u64, + "SRLI: rd_imm_rem_add mismatch" + ); + config + .rd_imm_rem_add + .assign_limb_with_carry(instance, &rd_imm_rem_add); + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + config.imm.assign_value(instance, imm); + config.rd_written.assign_value(instance, rd_written); + config.remainder.assign_value(instance, remainder); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/test.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/test.rs new file mode 100644 index 000000000..c516d0ef6 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/test.rs @@ -0,0 +1,94 @@ +use ceno_emul::{Change, StepRecord}; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; +use multilinear_extensions::mle::IntoMLEs; + +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + scheme::mock_prover::{MockProver, MOCK_PC_SRLI, MOCK_PC_SRLI_31, MOCK_PROGRAM}, +}; + +use super::{shift_imm_circuit::ShiftImmInstruction, SrliOp}; + +#[test] +fn test_opcode_srli_1() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "srli", + |cb| { + let config = ShiftImmInstruction::::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let (raw_witin, _) = ShiftImmInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_i_instruction( + 3, + 
MOCK_PC_SRLI, + MOCK_PROGRAM[10], + 32, + Change::new(0, 32 >> 3), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} + +#[test] +fn test_opcode_srli_2() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "srli", + |cb| { + let config = ShiftImmInstruction::::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let (raw_witin, _) = ShiftImmInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_i_instruction( + 3, + MOCK_PC_SRLI_31, + MOCK_PROGRAM[11], + 32, + Change::new(0, 32 >> 31), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + None, + ); +} diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index efac2dbc9..5af066611 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -10,6 +10,7 @@ pub use utils::u64vec; mod chip_handler; pub mod circuit_builder; pub mod expression; +pub mod gadgets; mod keygen; pub mod structs; mod uint; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 31f991ecf..31024a422 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -11,7 +11,7 @@ use crate::{ }; use ark_std::test_rng; use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine}; -use ceno_emul::{ByteAddr, CENO_PLATFORM}; +use ceno_emul::{ByteAddr, Change, CENO_PLATFORM}; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; use goldilocks::SmallField; @@ -19,6 +19,7 @@ use itertools::Itertools; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use std::{ collections::HashSet, + fmt::Write, fs::{self, File}, hash::Hash, io::{BufReader, ErrorKind}, @@ 
-30,9 +31,12 @@ use std::{ pub const MOCK_RS1: u32 = 2; pub const MOCK_RS2: u32 = 3; pub const MOCK_RD: u32 = 4; +pub const MOCK_IMM_3: u32 = 3; +pub const MOCK_IMM_31: u32 = 31; /// The program baked in the MockProver. /// TODO: Make this a parameter? #[allow(clippy::identity_op)] +#[allow(clippy::unusual_byte_groupings)] pub const MOCK_PROGRAM: &[u32] = &[ // R-Type // funct7 | rs2 | rs1 | funct3 | rd | opcode @@ -49,6 +53,19 @@ pub const MOCK_PROGRAM: &[u32] = &[ 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b110 << 12 | MOCK_RD << 7 | 0x33, // xor x4, x2, x3 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b100 << 12 | MOCK_RD << 7 | 0x33, + // B-Type + // beq x2, x3, 8 + 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b000 << 12 | 0x08 << 7 | 0x63, + // bne x2, x3, 8 + 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | 0x08 << 7 | 0x63, + // blt x2, x3, -8 + 0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_100 << 12 | 0b_1100_1 << 7 | 0x63, + // divu (0x01, 0x05, 0x33) + 0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x33, + // srli x4, x2, 3 + 0x00 << 25 | MOCK_IMM_3 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x13, + // srli x4, x2, 31 + 0x00 << 25 | MOCK_IMM_31 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x13, ]; // Addresses of particular instructions in the mock program. 
pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); @@ -57,6 +74,15 @@ pub const MOCK_PC_MUL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 8); pub const MOCK_PC_AND: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 12); pub const MOCK_PC_OR: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 16); pub const MOCK_PC_XOR: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 20); +pub const MOCK_PC_BEQ: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 24); +pub const MOCK_PC_BNE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 28); +pub const MOCK_PC_BLT: Change = Change { + before: ByteAddr(CENO_PLATFORM.pc_start() + 32), + after: ByteAddr(CENO_PLATFORM.pc_start() + 24), +}; +pub const MOCK_PC_DIVU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 36); +pub const MOCK_PC_SRLI: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 40); +pub const MOCK_PC_SRLI_31: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 44); #[allow(clippy::enum_variant_names)] #[derive(Debug, PartialEq, Clone)] @@ -104,7 +130,7 @@ impl MockProverError { "\nAssertZeroError {name:?}: Evaluated expression is not zero\n\ Expression: {expression_fmt}\n\ Evaluation: {eval_fmt} != 0\n\ - Inst[{inst_id}]: {wtns_fmt}\n", + Inst[{inst_id}]:\n{wtns_fmt}\n", ); } Self::AssertEqualError { @@ -125,7 +151,7 @@ impl MockProverError { Left: {left_eval_fmt} != Right: {right_eval_fmt}\n\ Left Expression: {left_expression_fmt}\n\ Right Expression: {right_expression_fmt}\n\ - Inst[{inst_id}]: {wtns_fmt}\n", + Inst[{inst_id}]:\n{wtns_fmt}\n", ); } Self::LookupError { @@ -141,115 +167,136 @@ impl MockProverError { "\nLookupError {name:#?}: Evaluated expression does not exist in T vector\n\ Expression: {expression_fmt}\n\ Evaluation: {eval_fmt}\n\ - Inst[{inst_id}]: {wtns_fmt}\n", + Inst[{inst_id}]:\n{wtns_fmt}\n", ); } } + } +} - fn fmt_expr( - expression: &Expression, - wtns: &mut Vec, - add_prn_sum: bool, - ) -> String { - match expression { - Expression::WitIn(wit_in) => { - wtns.push(*wit_in); - format!("WitIn({})", 
wit_in) - } - Expression::Challenge(id, _, _, _) => format!("Challenge({})", id), - Expression::Constant(constant) => fmt_base_field::(constant, true).to_string(), - Expression::Fixed(fixed) => format!("{:?}", fixed), - Expression::Sum(left, right) => { - let s = format!( - "{} + {}", - fmt_expr(left, wtns, false), - fmt_expr(right, wtns, false) - ); - if add_prn_sum { format!("({})", s) } else { s } +fn fmt_expr( + expression: &Expression, + wtns: &mut Vec, + add_prn_sum: bool, +) -> String { + match expression { + Expression::WitIn(wit_in) => { + wtns.push(*wit_in); + format!("WitIn({})", wit_in) + } + Expression::Challenge(id, pow, scaler, offset) => { + if *pow == 1 && *scaler == 1.into() && *offset == 0.into() { + format!("Challenge({})", id) + } else { + let mut s = String::new(); + if *scaler != 1.into() { + write!(s, "{}*", fmt_field(scaler)).unwrap(); } - Expression::Product(left, right) => { - format!( - "{} * {}", - fmt_expr(left, wtns, true), - fmt_expr(right, wtns, true) - ) + write!(s, "Challenge({})", id,).unwrap(); + if *pow > 1 { + write!(s, "^{}", pow).unwrap(); } - Expression::ScaledSum(x, a, b) => { - let s = format!( - "{} * {} + {}", - fmt_expr(a, wtns, true), - fmt_expr(x, wtns, true), - fmt_expr(b, wtns, false) - ); - if add_prn_sum { format!("({})", s) } else { s } + if *offset != 0.into() { + write!(s, "+{}", fmt_field(offset)).unwrap(); } + s } } - - fn fmt_field(field: &E) -> String { - let name = format!("{:?}", field); - let name = name.split('(').next().unwrap_or("ExtensionField"); + Expression::Constant(constant) => fmt_base_field::(constant, true).to_string(), + Expression::Fixed(fixed) => format!("{:?}", fixed), + Expression::Sum(left, right) => { + let s = format!( + "{} + {}", + fmt_expr(left, wtns, false), + fmt_expr(right, wtns, false) + ); + if add_prn_sum { format!("({})", s) } else { s } + } + Expression::Product(left, right) => { format!( - "{name}[{}]", - field - .as_bases() - .iter() - .map(|b| fmt_base_field::(b, false)) 
- .collect::>() - .join(",") + "{} * {}", + fmt_expr(left, wtns, true), + fmt_expr(right, wtns, true) ) } - - fn fmt_base_field(base_field: &E::BaseField, add_prn: bool) -> String { - let value = base_field.to_canonical_u64(); - - if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 { - // beautiful format for negative number > -65536 - fmt_prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn) - } else if value < u16::MAX as u64 { - format!("{value}") - } else { - // hex - if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { - fmt_prn( - format!("-{:#x}", E::BaseField::MODULUS_U64 - value), - add_prn, - ) - } else { - format!("{value:#x}") - } - } + Expression::ScaledSum(x, a, b) => { + let s = format!( + "{} * {} + {}", + fmt_expr(a, wtns, true), + fmt_expr(x, wtns, true), + fmt_expr(b, wtns, false) + ); + if add_prn_sum { format!("({})", s) } else { s } } + } +} - fn fmt_prn(s: String, add_prn: bool) -> String { - if add_prn { format!("({})", s) } else { s } - } +fn fmt_field(field: &E) -> String { + let name = format!("{:?}", field); + let name = name.split('(').next().unwrap_or("ExtensionField"); + + let data = field + .as_bases() + .iter() + .map(|b| fmt_base_field::(b, false)) + .collect::>(); + let only_one_limb = field.as_bases()[1..].iter().all(|&x| x == 0.into()); + + if only_one_limb { + data[0].to_string() + } else { + format!("{name}[{}]", data.join(",")) + } +} - fn fmt_wtns( - wtns: &[WitnessId], - wits_in: &[ArcMultilinearExtension], - inst_id: usize, - wits_in_name: &[String], - ) -> String { - wtns.iter() - .sorted() - .map(|wt_id| { - let wit = &wits_in[*wt_id as usize]; - let name = &wits_in_name[*wt_id as usize]; - let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() { - fmt_field(&e[inst_id]) - } else if let Some(bf) = wit.get_base_field_vec_optn() { - fmt_base_field::(&bf[inst_id], true) - } else { - "Unknown".to_string() - }; - format!("\nWitIn({wt_id})\npath={name}\nvalue={value_fmt}\n") - }) - 
.join("----\n") +fn fmt_base_field(base_field: &E::BaseField, add_prn: bool) -> String { + let value = base_field.to_canonical_u64(); + + if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 { + // beautiful format for negative number > -65536 + fmt_prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn) + } else if value < u16::MAX as u64 { + format!("{value}") + } else { + // hex + if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { + fmt_prn( + format!("-{:#x}", E::BaseField::MODULUS_U64 - value), + add_prn, + ) + } else { + format!("{value:#x}") } } } +fn fmt_prn(s: String, add_prn: bool) -> String { + if add_prn { format!("({})", s) } else { s } +} + +fn fmt_wtns( + wtns: &[WitnessId], + wits_in: &[ArcMultilinearExtension], + inst_id: usize, + wits_in_name: &[String], +) -> String { + wtns.iter() + .sorted() + .map(|wt_id| { + let wit = &wits_in[*wt_id as usize]; + let name = &wits_in_name[*wt_id as usize]; + let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() { + fmt_field(&e[inst_id]) + } else if let Some(bf) = wit.get_base_field_vec_optn() { + fmt_base_field::(&bf[inst_id], true) + } else { + "Unknown".to_string() + }; + format!(" WitIn({wt_id})={value_fmt} {name:?}") + }) + .join("\n") +} + pub(crate) struct MockProver { _phantom: PhantomData, } @@ -524,6 +571,40 @@ mod tests { use goldilocks::{Goldilocks, GoldilocksExt2}; use multilinear_extensions::mle::{IntoMLE, IntoMLEs}; + #[test] + fn test_fmt_expr_challenge_1() { + let a = Expression::::Challenge(0, 2, 3.into(), 4.into()); + let b = Expression::::Challenge(0, 5, 6.into(), 7.into()); + + let mut wtns_acc = vec![]; + let s = fmt_expr(&(a * b), &mut wtns_acc, false); + + assert_eq!( + s, + "18*Challenge(0)^7+28 + 21*Challenge(0)^2 + 24*Challenge(0)^5" + ); + } + + #[test] + fn test_fmt_expr_challenge_2() { + let a = Expression::::Challenge(0, 1, 1.into(), 0.into()); + let b = Expression::::Challenge(0, 1, 1.into(), 0.into()); + + let mut wtns_acc = vec![]; + 
let s = fmt_expr(&(a * b), &mut wtns_acc, false); + + assert_eq!(s, "Challenge(0)^2"); + } + + #[test] + fn test_fmt_expr_wtns_acc_1() { + let expr = Expression::::WitIn(0); + let mut wtns_acc = vec![]; + let s = fmt_expr(&expr, &mut wtns_acc, false); + assert_eq!(s, "WitIn(0)"); + assert_eq!(wtns_acc, vec![0]); + } + #[derive(Debug)] #[allow(dead_code)] struct AssertZeroCircuit { diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2f14cd0f3..d36bee6de 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -473,11 +473,30 @@ impl> ZKVMProver { assert!(sel_non_lc_zero_sumcheck.is_some()); // \sum_t (sel(rt, t) * (\sum_j alpha_{j} * all_monomial_terms(t) )) - for (expr, alpha) in cs + for ((expr, name), alpha) in cs .assert_zero_sumcheck_expressions .iter() + .zip_eq(cs.assert_zero_sumcheck_expressions_namespace_map.iter()) .zip_eq(alpha_pow_iter) { + // sanity check in debug build and output != instance index for zero check sumcheck poly + if cfg!(debug_assertions) { + let expected_zero_poly = wit_infer_by_expr(&[], &witnesses, challenges, expr); + let top_100_errors = expected_zero_poly + .get_ext_field_vec() + .iter() + .enumerate() + .filter(|(_, v)| **v != E::ZERO) + .take(100) + .collect_vec(); + if !top_100_errors.is_empty() { + return Err(ZKVMError::InvalidWitness(format!( + "degree > 1 zero check virtual poly: expr {name} != 0 on instance indexes: {}...", + top_100_errors.into_iter().map(|(i, _)| i).join(",") + ))); + } + } + distrinct_zerocheck_terms_set.extend(virtual_polys.add_mle_list_by_expr( sel_non_lc_zero_sumcheck.as_ref(), witnesses.iter().collect_vec(), diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 67e577473..4a0e94096 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -43,6 +43,7 @@ pub type ChallengeId = u16; #[derive(Copy, Clone, Debug)] pub enum ROMType { U5 = 0, // 2^5 = 32 + U1, // TODO: optimize it U8, // 2^8 = 256 U16, // 2^16 = 65,536 
And, // a & b where a, b are bytes @@ -56,6 +57,7 @@ pub enum ROMType { pub enum RAMType { GlobalState, Register, + Memory, } /// A point is a vector of num_var length diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 6acb8991c..8a3b50070 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -10,8 +10,9 @@ use crate::{ tables::TableCircuit, witness::RowMajorMatrix, }; -use ceno_emul::{DecodedInstruction, Word, CENO_PLATFORM, WORD_SIZE}; +use ceno_emul::{DecodedInstruction, Word, CENO_PLATFORM, PC_STEP_SIZE, WORD_SIZE}; use ff_ext::ExtensionField; +use goldilocks::SmallField; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; @@ -35,22 +36,27 @@ impl InsnRecord { &self.0[1] } - pub fn rd(&self) -> &T { + pub fn rd_or_zero(&self) -> &T { &self.0[2] } - pub fn funct3(&self) -> &T { + pub fn funct3_or_zero(&self) -> &T { &self.0[3] } - pub fn rs1(&self) -> &T { + pub fn rs1_or_zero(&self) -> &T { &self.0[4] } - pub fn rs2(&self) -> &T { + pub fn rs2_or_zero(&self) -> &T { &self.0[5] } + /// Iterate through the fields, except immediate because it is complicated. + fn without_imm(&self) -> &[T] { + &self.0[0..6] + } + /// The complete immediate value, for instruction types I/S/B/U/J. /// Otherwise, the field funct7 of R-Type instructions. pub fn imm_or_funct7(&self) -> &T { @@ -63,13 +69,23 @@ impl InsnRecord { InsnRecord::new( pc, insn.opcode(), - insn.rd(), - insn.funct3(), - insn.rs1(), - insn.rs2(), - insn.funct7(), // TODO: get immediate for all types. + insn.rd_or_zero(), + insn.funct3_or_zero(), + insn.rs1_or_zero(), + insn.rs2_or_zero(), + insn.imm_or_funct7(), ) } + + /// Interpret the immediate or funct7 as unsigned or signed depending on the instruction. + /// Convert negative values from two's complement to field. 
+ pub fn imm_or_funct7_field(insn: &DecodedInstruction) -> F { + if insn.imm_is_negative() { + -F::from(-(insn.imm_or_funct7() as i32) as u64) + } else { + F::from(insn.imm_or_funct7() as u64) + } + } } #[derive(Clone, Debug)] @@ -137,13 +153,25 @@ impl TableCircuit for ProgramTableCircuit { .with_min_len(MIN_PAR_SIZE) .zip((0..num_instructions).into_par_iter()) .for_each(|(row, i)| { - let pc = pc_start + (i * WORD_SIZE) as u32; + let pc = pc_start + (i * PC_STEP_SIZE) as u32; let insn = DecodedInstruction::new(program[i]); let values = InsnRecord::from_decoded(pc, &insn); - for (col, val) in config.record.as_slice().iter().zip_eq(values.as_slice()) { + // Copy all the fields except immediate. + for (col, val) in config + .record + .without_imm() + .iter() + .zip_eq(values.without_imm()) + { set_fixed_val!(row, *col, E::BaseField::from(*val as u64)); } + + set_fixed_val!( + row, + config.record.imm_or_funct7(), + InsnRecord::imm_or_funct7_field(&insn) + ); }); fixed diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index cdbf5aca1..53a5d7c2c 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -7,6 +7,15 @@ pub use range_circuit::{RangeTable, RangeTableCircuit}; use crate::structs::ROMType; +pub struct U1Table; +impl RangeTable for U1Table { + const ROM_TYPE: ROMType = ROMType::U1; + fn len() -> usize { + 1 << 1 + } +} +pub type U1TableCircuit = RangeTableCircuit; + pub struct U5Table; impl RangeTable for U5Table { const ROM_TYPE: ROMType = ROMType::U5; diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index bc30b5b4b..a2cfa2737 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -18,6 +18,7 @@ use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; use std::{ + borrow::Cow, mem::{self, MaybeUninit}, ops::Index, }; @@ -147,6 +148,23 @@ impl UIntLimbs { } } + pub fn assign_value + Default + From + Copy>( + &self, + instance: &mut [MaybeUninit], + value: Value, + ) 
{ + self.assign_limbs(instance, value.u16_fields()) + } + + pub fn assign_limb_with_carry( + &self, + instance: &mut [MaybeUninit], + (limbs, carry): &(Vec, Vec), + ) { + self.assign_limbs(instance, limbs.iter().map(|v| (*v as u64).into()).collect()); + self.assign_carries(instance, carry.iter().map(|v| (*v as u64).into()).collect()); + } + pub fn assign_limbs( &self, instance: &mut [MaybeUninit], @@ -419,6 +437,16 @@ impl UIntLimbs { res } + + /// Get an Expression from the limbs, unsafe if Uint value exceeds field limit + pub fn value(&self) -> Expression { + let base = Expression::from(1 << C); + self.expr() + .into_iter() + .rev() + .reduce(|sum, limb| sum * base.clone() + limb) + .unwrap() + } } /// Construct `UIntLimbs` from `Vec` @@ -469,7 +497,7 @@ impl UIntLimbs<32, 16, E> { /// Return a value suitable for register read/write. From [u16; 2] limbs. pub fn register_expr(&self) -> RegisterExpr { let u16_limbs = self.expr(); - RegisterExpr(u16_limbs.try_into().expect("two limbs with M=32 and C=16")) + u16_limbs.try_into().expect("two limbs with M=32 and C=16") } } @@ -484,19 +512,19 @@ impl UIntLimbs<32, 8, E> { a + b * 256.into() }) .collect_vec(); - RegisterExpr(u16_limbs.try_into().expect("four limbs with M=32 and C=8")) + u16_limbs.try_into().expect("four limbs with M=32 and C=8") } } -pub struct Value + Copy> { +pub struct Value<'a, T: Into + From + Copy + Default> { #[allow(dead_code)] val: T, - pub limbs: Vec, + pub limbs: Cow<'a, [u16]>, } // TODO generalize to support non 16 bit limbs // TODO optimize api with fixed size array -impl + Copy> Value { +impl<'a, T: Into + From + Copy + Default> Value<'a, T> { const LIMBS: usize = { let u16_bytes = (u16::BITS / 8) as usize; mem::size_of::() / u16_bytes @@ -505,7 +533,7 @@ impl + Copy> Value { pub fn new(val: T, lkm: &mut LkMultiplicity) -> Self { let uint = Value:: { val, - limbs: Self::split_to_u16(val), + limbs: Cow::Owned(Self::split_to_u16(val)), }; Self::assert_u16(&uint.limbs, lkm); uint @@ -514,7 
+542,29 @@ impl + Copy> Value { pub fn new_unchecked(val: T) -> Self { Value:: { val, - limbs: Self::split_to_u16(val), + limbs: Cow::Owned(Self::split_to_u16(val)), + } + } + + pub fn from_limb_unchecked(limbs: Vec) -> Self { + Value:: { + val: limbs + .iter() + .rev() + .fold(0u32, |acc, &v| acc * (1 << 16) + v as u32) + .into(), + limbs: Cow::Owned(limbs), + } + } + + pub fn from_limb_slice_unchecked(limbs: &'a [u16]) -> Self { + Value:: { + val: limbs + .iter() + .rev() + .fold(0u32, |acc, &v| acc * (1 << 16) + v as u32) + .into(), + limbs: Cow::Borrowed(limbs), } } @@ -539,6 +589,11 @@ impl + Copy> Value { &self.limbs } + /// Convert the limbs to a u64 value + pub fn as_u64(&self) -> u64 { + self.val.into() + } + pub fn u16_fields(&self) -> Vec { self.limbs.iter().map(|v| F::from(*v as u64)).collect_vec() } @@ -548,16 +603,16 @@ impl + Copy> Value { rhs: &Self, lkm: &mut LkMultiplicity, with_overflow: bool, - ) -> (Vec, Vec) { + ) -> (Vec, Vec) { let res = self.as_u16_limbs().iter().zip(rhs.as_u16_limbs()).fold( vec![], |mut acc, (a_limb, b_limb)| { let (a, b) = a_limb.overflowing_add(*b_limb); if let Some((_, prev_carry)) = acc.last() { - let (e, d) = a.overflowing_add(*prev_carry as u16); - acc.push((e, b || d)); + let (e, d) = a.overflowing_add(*prev_carry); + acc.push((e, (b || d) as u16)); } else { - acc.push((a, b)); + acc.push((a, b as u16)); } // range check if let Some((limb, _)) = acc.last() { @@ -566,9 +621,9 @@ impl + Copy> Value { acc }, ); - let (limbs, mut carries): (Vec, Vec) = res.into_iter().unzip(); + let (limbs, mut carries): (Vec, Vec) = res.into_iter().unzip(); if !with_overflow { - carries.resize(carries.len() - 1, false); + carries.resize(carries.len() - 1, 0); } carries.iter().for_each(|c| lkm.assert_ux::<16>(*c as u64)); (limbs, carries) @@ -579,22 +634,46 @@ impl + Copy> Value { rhs: &Self, lkm: &mut LkMultiplicity, with_overflow: bool, + ) -> (Vec, Vec) { + self.internal_mul(rhs, lkm, with_overflow) + } + + pub fn mul_add( + &self, 
+ mul: &Self, + addend: &Self, + lkm: &mut LkMultiplicity, + with_overflow: bool, + ) -> (Vec, Vec, Vec) { + let (ret, mul_carries) = self.internal_mul(mul, lkm, with_overflow); + let (ret, add_carries) = addend.add(&Self::from_limb_unchecked(ret), lkm, with_overflow); + (ret, mul_carries, add_carries) + } + + fn internal_mul( + &self, + mul: &Self, + lkm: &mut LkMultiplicity, + with_overflow: bool, ) -> (Vec, Vec) { let a_limbs = self.as_u16_limbs(); - let b_limbs = rhs.as_u16_limbs(); + let b_limbs = mul.as_u16_limbs(); let num_limbs = a_limbs.len(); let mut c_limbs = vec![0u16; num_limbs]; let mut carries = vec![0u16; num_limbs]; - a_limbs.iter().enumerate().for_each(|(i, a_limb)| { - b_limbs.iter().enumerate().for_each(|(j, b_limb)| { + a_limbs.iter().enumerate().for_each(|(i, &a_limb)| { + b_limbs.iter().enumerate().for_each(|(j, &b_limb)| { let idx = i + j; if idx < num_limbs { - let (c, overflow_mul) = a_limb.overflowing_mul(*b_limb); + let (c, overflow_mul) = a_limb.overflowing_mul(b_limb); let (ret, overflow_add) = c_limbs[idx].overflowing_add(c); c_limbs[idx] = ret; - carries[idx] += (overflow_add as u16) + (overflow_mul as u16); + carries[idx] += overflow_add as u16; + if overflow_mul { + carries[idx] += ((a_limb as u32 * b_limb as u32) / (1 << 16)) as u16; + } } }) }); @@ -623,73 +702,73 @@ impl + Copy> Value { #[cfg(test)] mod tests { - use crate::witness::LkMultiplicity; - - use super::Value; - - #[test] - fn test_add() { - let a = Value::new_unchecked(1u32); - let b = Value::new_unchecked(2u32); - let mut lkm = LkMultiplicity::default(); - - let (c, carries) = a.add(&b, &mut lkm, true); - assert_eq!(c[0], 3); - assert_eq!(c[1], 0); - assert_eq!(carries[0], false); - assert_eq!(carries[1], false); - } - - #[test] - fn test_add_carry() { - let a = Value::new_unchecked(u16::MAX as u32); - let b = Value::new_unchecked(2u32); - let mut lkm = LkMultiplicity::default(); - - let (c, carries) = a.add(&b, &mut lkm, true); - assert_eq!(c[0], 1); - 
assert_eq!(c[1], 1); - assert_eq!(carries[0], true); - assert_eq!(carries[1], false); - } - - #[test] - fn test_mul() { - let a = Value::new_unchecked(1u32); - let b = Value::new_unchecked(2u32); - let mut lkm = LkMultiplicity::default(); - - let (c, carries) = a.mul(&b, &mut lkm, true); - assert_eq!(c[0], 2); - assert_eq!(c[1], 0); - assert_eq!(carries[0], 0); - assert_eq!(carries[1], 0); - } - - #[test] - fn test_mul_carry() { - let a = Value::new_unchecked(u16::MAX as u32); - let b = Value::new_unchecked(2u32); - let mut lkm = LkMultiplicity::default(); - - let (c, carries) = a.mul(&b, &mut lkm, true); - assert_eq!(c[0], u16::MAX - 1); - assert_eq!(c[1], 1); - assert_eq!(carries[0], 1); - assert_eq!(carries[1], 0); - } - - #[test] - fn test_mul_overflow() { - let a = Value::new_unchecked(u32::MAX / 2 + 1); - let b = Value::new_unchecked(2u32); - let mut lkm = LkMultiplicity::default(); - - let (c, carries) = a.mul(&b, &mut lkm, true); - assert_eq!(c[0], 0); - assert_eq!(c[1], 0); - assert_eq!(carries[0], 0); - assert_eq!(carries[1], 1); + + mod value { + use crate::{witness::LkMultiplicity, Value}; + #[test] + fn test_add() { + let a = Value::new_unchecked(1u32); + let b = Value::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.add(&b, &mut lkm, true); + assert_eq!(c[0], 3); + assert_eq!(c[1], 0); + assert_eq!(carries[0], 0); + assert_eq!(carries[1], 0); + } + + #[test] + fn test_add_carry() { + let a = Value::new_unchecked(u16::MAX as u32); + let b = Value::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.add(&b, &mut lkm, true); + assert_eq!(c[0], 1); + assert_eq!(c[1], 1); + assert_eq!(carries[0], 1); + assert_eq!(carries[1], 0); + } + + #[test] + fn test_mul() { + let a = Value::new_unchecked(1u32); + let b = Value::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.mul(&b, &mut lkm, true); + assert_eq!(c[0], 2); + assert_eq!(c[1], 0); + 
assert_eq!(carries[0], 0); + assert_eq!(carries[1], 0); + } + + #[test] + fn test_mul_carry() { + let a = Value::new_unchecked(u16::MAX as u32); + let b = Value::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.mul(&b, &mut lkm, true); + assert_eq!(c[0], u16::MAX - 1); + assert_eq!(c[1], 1); + assert_eq!(carries[0], 1); + assert_eq!(carries[1], 0); + } + + #[test] + fn test_mul_overflow() { + let a = Value::new_unchecked(u32::MAX / 2 + 1); + let b = Value::new_unchecked(2u32); + let mut lkm = LkMultiplicity::default(); + + let (c, carries) = a.mul(&b, &mut lkm, true); + assert_eq!(c[0], 0); + assert_eq!(c[1], 0); + assert_eq!(carries[0], 0); + assert_eq!(carries[1], 1); + } } // #[test] // fn test_uint_from_cell_ids() { diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index cf242ccc4..817c83ca0 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -18,8 +18,7 @@ impl UIntLimbs { fn internal_add( &self, circuit_builder: &mut CircuitBuilder, - addend1: &Vec>, - addend2: &Vec>, + addend: &Vec>, with_overflow: bool, ) -> Result, ZKVMError> { let mut c = UIntLimbs::::new_as_empty(); @@ -30,9 +29,9 @@ impl UIntLimbs { // perform add operation // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C c.limbs = UintLimb::Expression( - (*addend1) + (self.expr()) .iter() - .zip((*addend2).iter()) + .zip((*addend).iter()) .enumerate() .map(|(i, (a, b))| { let carries = c.carries.as_ref().unwrap(); @@ -78,7 +77,7 @@ impl UIntLimbs { }) .collect_vec(); - self.internal_add(cb, &self.expr(), &b_limbs, with_overflow) + self.internal_add(cb, &b_limbs, with_overflow) }) } @@ -91,7 +90,7 @@ impl UIntLimbs { with_overflow: bool, ) -> Result, ZKVMError> { circuit_builder.namespace(name_fn, |cb| { - self.internal_add(cb, &self.expr(), &addend.expr(), with_overflow) + self.internal_add(cb, &addend.expr(), with_overflow) }) } @@ -181,6 +180,23 @@ impl UIntLimbs { }) } + pub fn mul_add, N: 
FnOnce() -> NR>( + &mut self, + name_fn: N, + circuit_builder: &mut CircuitBuilder, + multiplier: &mut UIntLimbs, + addend: &UIntLimbs, + with_overflow: bool, + ) -> Result<(UIntLimbs, UIntLimbs), ZKVMError> { + circuit_builder.namespace(name_fn, |cb| { + let c = self.internal_mul(cb, multiplier, with_overflow)?; + Ok(( + c.clone(), + c.internal_add(cb, &addend.expr(), with_overflow)?, // propagate the error instead of panicking + )) + }) + } + /// Check two UIntLimbs are equal pub fn eq, N: FnOnce() -> NR>( &self, @@ -190,7 +206,10 @@ impl UIntLimbs { ) -> Result<(), ZKVMError> { circuit_builder.namespace(name_fn, |cb| { izip!(self.expr(), rhs.expr()) - .try_for_each(|(lhs, rhs)| cb.require_equal(|| "uint_eq", lhs, rhs)) + .enumerate() + .try_for_each(|(i, (lhs, rhs))| { + cb.require_equal(|| format!("uint_eq_{i}"), lhs, rhs) + }) }) } @@ -274,13 +293,13 @@ impl UIntLimbs { // indicate the first non-zero byte index i_0 of a[i] - b[i] // from high to low - indexes - .iter() - .try_for_each(|idx| circuit_builder.assert_bit(|| "bit assert", idx.expr()))?; - let index_sum = indexes - .iter() - .fold(Expression::from(0), |acc, idx| acc + idx.expr()); - circuit_builder.assert_bit(|| "bit assert", index_sum)?; + // indexes + // .iter() + // .try_for_each(|idx| circuit_builder.assert_bit(|| "bit assert", idx.expr()))?; + // let index_sum = indexes + // .iter() + // .fold(Expression::from(0), |acc, idx| acc + idx.expr()); + // circuit_builder.assert_bit(|| "bit assert", index_sum)?; // equal zero if a==b, otherwise equal (a[i_0]-b[i_0])^{-1} let byte_diff_inv = circuit_builder.create_witin(|| "byte_diff_inverse")?; @@ -307,9 +326,10 @@ impl UIntLimbs { si.iter() .zip(self.limbs.iter()) .zip(rhs.limbs.iter()) - .try_for_each(|((flag, a), b)| { + .enumerate() + .try_for_each(|(i, ((flag, a), b))| { circuit_builder.require_zero( - || "byte diff zero check", + || format!("byte diff {i} zero check"), a.expr() - b.expr() - flag.expr() * a.expr() + flag.expr() * b.expr(), ) })?; @@ -346,7 +366,6 @@ impl UIntLimbs { 
)?; let is_ltu = circuit_builder.create_witin(|| "is_ltu")?; - // circuit_builder.assert_bit(is_ltu.expr())?; // lookup ensure it is bit // now we know the first non-equal byte pairs is (lhs_ne_byte, rhs_ne_byte) circuit_builder.lookup_ltu_byte(lhs_ne_byte.expr(), rhs_ne_byte.expr(), is_ltu.expr())?; Ok(UIntLtuConfig { @@ -365,7 +384,7 @@ impl UIntLimbs { rhs: &UIntLimbs, ) -> Result { let is_lt = circuit_builder.create_witin(|| "is_lt")?; - circuit_builder.assert_bit(|| "assert_bit", is_lt.expr())?; + // circuit_builder.assert_bit(|| "assert_bit", is_lt.expr())?; let lhs_msb = self.msb_decompose(circuit_builder)?; let rhs_msb = rhs.msb_decompose(circuit_builder)?; @@ -694,13 +713,10 @@ mod tests { #[test] fn test_mul32_16_w_carries() { - // a = 256 - // b = 257 - // c = 256 + 1 * 2^16 = 65,792 - let wit_a = vec![256, 0]; - let wit_b = vec![257, 0]; - let wit_c = vec![256, 1]; - let wit_carries = vec![1, 0]; + let wit_a = vec![48683, 2621]; + let wit_b = vec![7, 0]; + let wit_c = vec![13101, 18352]; + let wit_carries = vec![5, 0]; let witness_values = [wit_a, wit_b, wit_c, wit_carries].concat(); verify::<32, 16, E>(witness_values, false); } @@ -969,5 +985,50 @@ mod tests { ); }); } + + #[test] + fn test_mul_add2() { + // c = a * b + // e = c + d + + // a = 1 + 1 * 2^16 + // b = 2 + 1 * 2^16 + // ==> c = 2 + 3 * 2^16 + 1 * 2^32 + // d = 1 + 1 * 2^16 + // ==> e = 3 + 4 * 2^16 + 1 * 2^32 + let a = vec![1, 1, 0, 0]; + let b = vec![2, 1, 0, 0]; + let c = vec![2, 3, 1, 0]; + let c_carries = vec![0; 3]; + // e = c + d + let d = vec![1, 1, 0, 0]; + let e = vec![3, 4, 1, 0]; + let e_carries = vec![0; 3]; + + let witness_values: Vec = [a, b, d, c, c_carries, e_carries] + .concat() + .iter() + .map(|&a| a.into()) + .collect_vec(); + + let mut cs = ConstraintSystem::new(|| "test_mul_add"); + let mut cb = CircuitBuilder::::new(&mut cs); + let challenges = (0..witness_values.len()).map(|_| 1.into()).collect_vec(); + + let mut uint_a = UIntLimbs::<64, 16, E>::new(|| "uint_a", 
&mut cb).unwrap(); + let mut uint_b = UIntLimbs::<64, 16, E>::new(|| "uint_b", &mut cb).unwrap(); + let mut uint_d = UIntLimbs::<64, 16, E>::new(|| "uint_d", &mut cb).unwrap(); + let (_, uint_e) = uint_a + .mul_add(|| "uint_c", &mut cb, &mut uint_b, &mut uint_d, false) + .unwrap(); + + uint_e.expr().iter().enumerate().for_each(|(i, ret)| { + // limbs check + assert_eq!( + eval_by_expr(&witness_values, &challenges, ret), + E::from(e.clone()[i]) + ); + }); + } } } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 02edbb078..d8e1c710d 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -2,6 +2,7 @@ use ff::Field; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; +use std::mem; use transcript::Transcript; /// convert ext field element to u64, assume it is inside the range @@ -33,6 +34,21 @@ pub fn limb_u8_to_u16(input: &[u8]) -> Vec { .collect() } +pub fn split_to_u8>(value: T) -> Vec { + let value: u64 = value.into(); // Convert to u64 for generality + let limbs: usize = { + let u8_bytes = (u8::BITS / 8) as usize; // width of one u8 limb in bytes, so every byte of T is emitted + mem::size_of::() / u8_bytes + }; + (0..limbs) + .scan(value, |acc, _| { + let limb = (*acc & 0xFF) as u8; + *acc >>= 8; + Some(limb) + }) + .collect_vec() +} + /// Compile time evaluated minimum function /// returns min(a, b) pub(crate) const fn const_min(a: usize, b: usize) -> usize { diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index cd5a1b186..128cb2fff 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -118,6 +118,7 @@ impl LkMultiplicity { 16 => self.increment(ROMType::U16, v), 8 => self.increment(ROMType::U8, v), 5 => self.increment(ROMType::U5, v), + 1 => self.increment(ROMType::U1, v), _ => panic!("Unsupported bit range"), } } diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs index e9cf65ddd..a870e1506 100644 --- a/mpcs/benches/basecode.rs +++ b/mpcs/benches/basecode.rs @@ -9,7 +9,8 @@ use mpcs::{ util::{ 
arithmetic::interpolate_field_type_over_boolean_hypercube, plonky2_util::reverse_index_bits_in_place_field_type, - }, Basefold, BasefoldBasecodeParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme + }, + Basefold, BasefoldBasecodeParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme, }; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; @@ -55,7 +56,6 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { }) .collect_vec(); - group.bench_function( BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), |b| { diff --git a/mpcs/benches/commit_open_verify_rs.rs b/mpcs/benches/commit_open_verify_rs.rs index eab916486..14840a918 100644 --- a/mpcs/benches/commit_open_verify_rs.rs +++ b/mpcs/benches/commit_open_verify_rs.rs @@ -10,11 +10,13 @@ use mpcs::{ PolynomialCommitmentScheme, }; -use multilinear_extensions::mle::{DenseMultilinearExtension, MultilinearExtension}; +use multilinear_extensions::{ + mle::{DenseMultilinearExtension, MultilinearExtension}, + virtual_poly_v2::ArcMultilinearExtension, +}; use rand::{rngs::OsRng, SeedableRng}; use rand_chacha::ChaCha8Rng; use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use transcript::Transcript; type Pcs = Basefold; @@ -293,9 +295,8 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c: &mut Criterion, is_base: }, ); - let polys: Vec> = polys.into_iter() - .map(|poly| poly.into()) - .collect_vec(); + let polys: Vec> = + polys.into_iter().map(|poly| poly.into()).collect_vec(); let point = (0..num_vars) .map(|_| transcript.get_and_append_challenge(b"Point").elements) diff --git a/mpcs/benches/hashing.rs b/mpcs/benches/hashing.rs new file mode 100644 index 000000000..d9baa6ab8 --- /dev/null +++ b/mpcs/benches/hashing.rs @@ -0,0 +1,39 @@ +use ark_std::test_rng; +use criterion::{criterion_group, criterion_main, Criterion}; +use ff::Field; +use goldilocks::Goldilocks; +use 
mpcs::util::hash::{hash_two_digests, new_hasher, Digest, DIGEST_WIDTH}; + +fn random_ceno_goldy() -> Goldilocks { + Goldilocks::random(&mut test_rng()) +} +pub fn criterion_benchmark(c: &mut Criterion) { + let hasher = new_hasher(); + let left = Digest( + vec![Goldilocks::random(&mut test_rng()); 4] + .try_into() + .unwrap(), + ); + let right = Digest( + vec![Goldilocks::random(&mut test_rng()); 4] + .try_into() + .unwrap(), + ); + c.bench_function("ceno hash 2 to 1", |bencher| { + bencher.iter(|| hash_two_digests(&left, &right, &hasher)) + }); + + let mut hasher = new_hasher(); + let values = (0..60) + .map(|_| Goldilocks::random(&mut test_rng())) + .collect::>(); + c.bench_function("ceno hash 60 to 1", |bencher| { + bencher.iter(|| { + hasher.update(values.as_slice()); + let result = &hasher.squeeze_vec()[0..DIGEST_WIDTH]; + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs index 30190b99a..9f2cb9813 100644 --- a/mpcs/benches/rscode.rs +++ b/mpcs/benches/rscode.rs @@ -10,8 +10,7 @@ use mpcs::{ arithmetic::interpolate_field_type_over_boolean_hypercube, plonky2_util::reverse_index_bits_in_place_field_type, }, - Basefold, BasefoldRSParams, BasefoldSpec, EncodingScheme, - PolynomialCommitmentScheme, + Basefold, BasefoldRSParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme, }; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; @@ -57,7 +56,6 @@ fn bench_encoding(c: &mut Criterion, is_base: bool) { }) .collect_vec(); - group.bench_function( BenchmarkId::new("batch_encode", format!("{}-{}", num_vars, batch_size)), |b| { diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml new file mode 100644 index 000000000..9560dffba --- /dev/null +++ b/poseidon/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "poseidon" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +criterion.workspace = true 
+ff.workspace = true +goldilocks.workspace = true +serde.workspace = true +unroll = "0.1.5" + +[dev-dependencies] +plonky2 = "0.2.2" +rand = "0.8.5" +ark-std.workspace = true + +[[bench]] +name = "hashing" +harness = false diff --git a/poseidon/benches/hashing.rs b/poseidon/benches/hashing.rs new file mode 100644 index 000000000..a9417c023 --- /dev/null +++ b/poseidon/benches/hashing.rs @@ -0,0 +1,132 @@ +use ark_std::test_rng; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use ff::Field; +use goldilocks::Goldilocks; +use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Sample}, + hash::{ + hash_types::HashOut, + hashing::PlonkyPermutation, + poseidon::{PoseidonHash as PlonkyPoseidonHash, PoseidonPermutation}, + }, + plonk::config::Hasher, +}; +use poseidon::{digest::Digest, poseidon_hash::PoseidonHash}; + +fn random_plonky_2_goldy() -> GoldilocksField { + GoldilocksField::rand() +} + +fn random_ceno_goldy() -> Goldilocks { + Goldilocks::random(&mut test_rng()) +} + +fn random_ceno_hash() -> Digest { + Digest( + vec![Goldilocks::random(&mut test_rng()); 4] + .try_into() + .unwrap(), + ) +} + +fn plonky_hash_single(a: GoldilocksField) { + let result = black_box(PlonkyPoseidonHash::hash_or_noop(&[a])); +} + +fn ceno_hash_single(a: Goldilocks) { + let result = black_box(PoseidonHash::hash_or_noop(&[a])); +} + +fn plonky_hash_2_to_1(left: HashOut, right: HashOut) { + let result = black_box(PlonkyPoseidonHash::two_to_one(left, right)); +} + +fn ceno_hash_2_to_1(left: &Digest, right: &Digest) { + let result = black_box(PoseidonHash::two_to_one(left, right)); +} + +fn plonky_hash_many_to_1(values: &[GoldilocksField]) { + let result = black_box(PlonkyPoseidonHash::hash_or_noop(values)); +} + +fn ceno_hash_many_to_1(values: &[Goldilocks]) { + let result = black_box(PoseidonHash::hash_or_noop(values)); +} + +pub fn hashing_benchmark(c: &mut Criterion) { + c.bench_function("plonky hash single", |bencher| { + 
bencher.iter_batched( + || random_plonky_2_goldy(), + |p_a| plonky_hash_single(p_a), + BatchSize::SmallInput, + ) + }); + + c.bench_function("plonky hash 2 to 1", |bencher| { + bencher.iter_batched( + || { + ( + HashOut::::rand(), + HashOut::::rand(), + ) + }, + |(left, right)| plonky_hash_2_to_1(left, right), + BatchSize::SmallInput, + ) + }); + + c.bench_function("plonky hash 60 to 1", |bencher| { + bencher.iter_batched( + || GoldilocksField::rand_vec(60), + |sixty_elems| plonky_hash_many_to_1(sixty_elems.as_slice()), + BatchSize::SmallInput, + ) + }); + + c.bench_function("ceno hash single", |bencher| { + bencher.iter_batched( + || random_ceno_goldy(), + |c_a| ceno_hash_single(c_a), + BatchSize::SmallInput, + ) + }); + + c.bench_function("ceno hash 2 to 1", |bencher| { + bencher.iter_batched( + || (random_ceno_hash(), random_ceno_hash()), + |(left, right)| ceno_hash_2_to_1(&left, &right), + BatchSize::SmallInput, + ) + }); + + c.bench_function("ceno hash 60 to 1", |bencher| { + bencher.iter_batched( + || { + (0..60) + .map(|_| Goldilocks::random(&mut test_rng())) + .collect::>() + }, + |values| ceno_hash_many_to_1(values.as_slice()), + BatchSize::SmallInput, + ) + }); +} + +// bench permutation +pub fn permutation_benchmark(c: &mut Criterion) { + let mut plonky_permutation = PoseidonPermutation::new(core::iter::repeat(GoldilocksField(0))); + let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::new( + core::iter::repeat(Goldilocks::ZERO), + ); + + c.bench_function("plonky permute", |bencher| { + bencher.iter(|| plonky_permutation.permute()) + }); + + c.bench_function("ceno permute", |bencher| { + bencher.iter(|| ceno_permutation.permute()) + }); +} + +criterion_group!(benches, permutation_benchmark, hashing_benchmark); +criterion_main!(benches); diff --git a/poseidon/src/constants.rs b/poseidon/src/constants.rs new file mode 100644 index 000000000..db170d100 --- /dev/null +++ b/poseidon/src/constants.rs @@ -0,0 +1,121 @@ +pub(crate) 
const DIGEST_WIDTH: usize = 4; + +pub(crate) const SPONGE_RATE: usize = 8; +pub(crate) const SPONGE_CAPACITY: usize = 4; +pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; + +// The number of full rounds and partial rounds is given by the +// calc_round_numbers.py script. They happen to be the same for both +// width 8 and width 12 with s-box x^7. +// +// NB: Changing any of these values will require regenerating all of +// the precomputed constant arrays in this file. +pub const HALF_N_FULL_ROUNDS: usize = 4; +pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; +pub const N_PARTIAL_ROUNDS: usize = 22; +pub const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; +const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) + +/// Note that these work for the Goldilocks field, but not necessarily others. See +/// `generate_constants` about how these were generated. We include enough for a width of 12; +/// smaller widths just use a subset. +#[rustfmt::skip] +pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ + // WARNING: The AVX2 Goldilocks specialization relies on all round constants being in + // 0..0xfffeeac900011537. If these constants are randomly regenerated, there is a ~.6% chance + // that this condition will no longer hold. + // + // WARNING: If these are changed in any way, then all the + // implementations of Poseidon must be regenerated. See comments + // in `poseidon_goldilocks.rs`. 
+ 0xb585f766f2144405, 0x7746a55f43921ad7, 0xb2fb0d31cee799b4, 0x0f6760a4803427d7, + 0xe10d666650f4e012, 0x8cae14cb07d09bf1, 0xd438539c95f63e9f, 0xef781c7ce35b4c3d, + 0xcdc4a239b0c44426, 0x277fa208bf337bff, 0xe17653a29da578a1, 0xc54302f225db2c76, + 0x86287821f722c881, 0x59cd1a8a41c18e55, 0xc3b919ad495dc574, 0xa484c4c5ef6a0781, + 0x308bbd23dc5416cc, 0x6e4a40c18f30c09c, 0x9a2eedb70d8f8cfa, 0xe360c6e0ae486f38, + 0xd5c7718fbfc647fb, 0xc35eae071903ff0b, 0x849c2656969c4be7, 0xc0572c8c08cbbbad, + 0xe9fa634a21de0082, 0xf56f6d48959a600d, 0xf7d713e806391165, 0x8297132b32825daf, + 0xad6805e0e30b2c8a, 0xac51d9f5fcf8535e, 0x502ad7dc18c2ad87, 0x57a1550c110b3041, + 0x66bbd30e6ce0e583, 0x0da2abef589d644e, 0xf061274fdb150d61, 0x28b8ec3ae9c29633, + 0x92a756e67e2b9413, 0x70e741ebfee96586, 0x019d5ee2af82ec1c, 0x6f6f2ed772466352, + 0x7cf416cfe7e14ca1, 0x61df517b86a46439, 0x85dc499b11d77b75, 0x4b959b48b9c10733, + 0xe8be3e5da8043e57, 0xf5c0bc1de6da8699, 0x40b12cbf09ef74bf, 0xa637093ecb2ad631, + 0x3cc3f892184df408, 0x2e479dc157bf31bb, 0x6f49de07a6234346, 0x213ce7bede378d7b, + 0x5b0431345d4dea83, 0xa2de45780344d6a1, 0x7103aaf94a7bf308, 0x5326fc0d97279301, + 0xa9ceb74fec024747, 0x27f8ec88bb21b1a3, 0xfceb4fda1ded0893, 0xfac6ff1346a41675, + 0x7131aa45268d7d8c, 0x9351036095630f9f, 0xad535b24afc26bfb, 0x4627f5c6993e44be, + 0x645cf794b8f1cc58, 0x241c70ed0af61617, 0xacb8e076647905f1, 0x3737e9db4c4f474d, + 0xe7ea5e33e75fffb6, 0x90dee49fc9bfc23a, 0xd1b1edf76bc09c92, 0x0b65481ba645c602, + 0x99ad1aab0814283b, 0x438a7c91d416ca4d, 0xb60de3bcc5ea751c, 0xc99cab6aef6f58bc, + 0x69a5ed92a72ee4ff, 0x5e7b329c1ed4ad71, 0x5fc0ac0800144885, 0x32db829239774eca, + 0x0ade699c5830f310, 0x7cc5583b10415f21, 0x85df9ed2e166d64f, 0x6604df4fee32bcb1, + 0xeb84f608da56ef48, 0xda608834c40e603d, 0x8f97fe408061f183, 0xa93f485c96f37b89, + 0x6704e8ee8f18d563, 0xcee3e9ac1e072119, 0x510d0e65e2b470c1, 0xf6323f486b9038f0, + 0x0b508cdeffa5ceef, 0xf2417089e4fb3cbd, 0x60e75c2890d15730, 0xa6217d8bf660f29c, + 0x7159cd30c3ac118e, 
0x839b4e8fafead540, 0x0d3f3e5e82920adc, 0x8f7d83bddee7bba8, + 0x780f2243ea071d06, 0xeb915845f3de1634, 0xd19e120d26b6f386, 0x016ee53a7e5fecc6, + 0xcb5fd54e7933e477, 0xacb8417879fd449f, 0x9c22190be7f74732, 0x5d693c1ba3ba3621, + 0xdcef0797c2b69ec7, 0x3d639263da827b13, 0xe273fd971bc8d0e7, 0x418f02702d227ed5, + 0x8c25fda3b503038c, 0x2cbaed4daec8c07c, 0x5f58e6afcdd6ddc2, 0x284650ac5e1b0eba, + 0x635b337ee819dab5, 0x9f9a036ed4f2d49f, 0xb93e260cae5c170e, 0xb0a7eae879ddb76d, + 0xd0762cbc8ca6570c, 0x34c6efb812b04bf5, 0x40bf0ab5fa14c112, 0xb6b570fc7c5740d3, + 0x5a27b9002de33454, 0xb1a5b165b6d2b2d2, 0x8722e0ace9d1be22, 0x788ee3b37e5680fb, + 0x14a726661551e284, 0x98b7672f9ef3b419, 0xbb93ae776bb30e3a, 0x28fd3b046380f850, + 0x30a4680593258387, 0x337dc00c61bd9ce1, 0xd5eca244c7a4ff1d, 0x7762638264d279bd, + 0xc1e434bedeefd767, 0x0299351a53b8ec22, 0xb2d456e4ad251b80, 0x3e9ed1fda49cea0b, + 0x2972a92ba450bed8, 0x20216dd77be493de, 0xadffe8cf28449ec6, 0x1c4dbb1c4c27d243, + 0x15a16a8a8322d458, 0x388a128b7fd9a609, 0x2300e5d6baedf0fb, 0x2f63aa8647e15104, + 0xf1c36ce86ecec269, 0x27181125183970c9, 0xe584029370dca96d, 0x4d9bbc3e02f1cfb2, + 0xea35bc29692af6f8, 0x18e21b4beabb4137, 0x1e3b9fc625b554f4, 0x25d64362697828fd, + 0x5a3f1bb1c53a9645, 0xdb7f023869fb8d38, 0xb462065911d4e1fc, 0x49c24ae4437d8030, + 0xd793862c112b0566, 0xaadd1106730d8feb, 0xc43b6e0e97b0d568, 0xe29024c18ee6fca2, + 0x5e50c27535b88c66, 0x10383f20a4ff9a87, 0x38e8ee9d71a45af8, 0xdd5118375bf1a9b9, + 0x775005982d74d7f7, 0x86ab99b4dde6c8b0, 0xb1204f603f51c080, 0xef61ac8470250ecf, + 0x1bbcd90f132c603f, 0x0cd1dabd964db557, 0x11a3ae5beb9d1ec9, 0xf755bfeea585d11d, + 0xa3b83250268ea4d7, 0x516306f4927c93af, 0xddb4ac49c9efa1da, 0x64bb6dec369d4418, + 0xf9cc95c22b4c1fcc, 0x08d37f755f4ae9f6, 0xeec49b613478675b, 0xf143933aed25e0b0, + 0xe4c5dd8255dfc622, 0xe7ad7756f193198e, 0x92c2318b87fff9cb, 0x739c25f8fd73596d, + 0x5636cac9f16dfed0, 0xdd8f909a938e0172, 0xc6401fe115063f5b, 0x8ad97b33f1ac1455, + 0x0c49366bb25e8513, 0x0784d3d2f1698309, 
0x530fb67ea1809a81, 0x410492299bb01f49, + 0x139542347424b9ac, 0x9cb0bd5ea1a1115e, 0x02e3f615c38f49a1, 0x985d4f4a9c5291ef, + 0x775b9feafdcd26e7, 0x304265a6384f0f2d, 0x593664c39773012c, 0x4f0a2e5fb028f2ce, + 0xdd611f1000c17442, 0xd8185f9adfea4fd0, 0xef87139ca9a3ab1e, 0x3ba71336c34ee133, + 0x7d3a455d56b70238, 0x660d32e130182684, 0x297a863f48cd1f43, 0x90e0a736a751ebb7, + 0x549f80ce550c4fd3, 0x0f73b2922f38bd64, 0x16bf1f73fb7a9c3f, 0x6d1f5a59005bec17, + 0x02ff876fa5ef97c4, 0xc5cb72a2a51159b0, 0x8470f39d2d5c900e, 0x25abb3f1d39fcb76, + 0x23eb8cc9b372442f, 0xd687ba55c64f6364, 0xda8d9e90fd8ff158, 0xe3cbdc7d2fe45ea7, + 0xb9a8c9b3aee52297, 0xc0d28a5c10960bd3, 0x45d7ac9b68f71a34, 0xeeb76e397069e804, + 0x3d06c8bd1514e2d9, 0x9c9c98207cb10767, 0x65700b51aedfb5ef, 0x911f451539869408, + 0x7ae6849fbc3a0ec6, 0x3bb340eba06afe7e, 0xb46e9d8b682ea65e, 0x8dcf22f9a3b34356, + 0x77bdaeda586257a7, 0xf19e400a5104d20d, 0xc368a348e46d950f, 0x9ef1cd60e679f284, + 0xe89cd854d5d01d33, 0x5cd377dc8bb882a2, 0xa7b0fb7883eee860, 0x7684403ec392950d, + 0x5fa3f06f4fed3b52, 0x8df57ac11bc04831, 0x2db01efa1e1e1897, 0x54846de4aadb9ca2, + 0xba6745385893c784, 0x541d496344d2c75b, 0xe909678474e687fe, 0xdfe89923f6c9c2ff, + 0xece5a71e0cfedc75, 0x5ff98fd5d51fe610, 0x83e8941918964615, 0x5922040b47f150c1, + 0xf97d750e3dd94521, 0x5080d4c2b86f56d7, 0xa7de115b56c78d70, 0x6a9242ac87538194, + 0xf7856ef7f9173e44, 0x2265fc92feb0dc09, 0x17dfc8e4f7ba8a57, 0x9001a64209f21db8, + 0x90004c1371b893c5, 0xb932b7cf752e5545, 0xa0b1df81b6fe59fc, 0x8ef1dd26770af2c2, + 0x0541a4f9cfbeed35, 0x9e61106178bfc530, 0xb3767e80935d8af2, 0x0098d5782065af06, + 0x31d191cd5c1466c7, 0x410fefafa319ac9d, 0xbdf8f242e316c4ab, 0x9e8cd55b57637ed0, + 0xde122bebe9a39368, 0x4d001fd58f002526, 0xca6637000eb4a9f8, 0x2f2339d624f91f78, + 0x6d1a7918c80df518, 0xdf9a4939342308e9, 0xebc2151ee6c8398c, 0x03cc2ba8a1116515, + 0xd341d037e840cf83, 0x387cb5d25af4afcc, 0xbba2515f22909e87, 0x7248fe7705f38e47, + 0x4d61e56a525d225a, 0x262e963c8da05d3d, 0x59e89b094d220ec2, 
0x055d5b52b78b9c5e, + 0x82b27eb33514ef99, 0xd30094ca96b7ce7b, 0xcf5cb381cd0a1535, 0xfeed4db6919e5a7c, + 0x41703f53753be59f, 0x5eeea940fcde8b6f, 0x4cd1f1b175100206, 0x4a20358574454ec0, + 0x1478d361dbbf9fac, 0x6f02dc07d141875c, 0x296a202ed8e556a2, 0x2afd67999bf32ee5, + 0x7acfd96efa95491d, 0x6798ba0c0abb2c6d, 0x34c6f57b26c92122, 0x5736e1bad206b5de, + 0x20057d2a0056521b, 0x3dea5bd5d0578bd7, 0x16e50d897d4634ac, 0x29bff3ecb9b7a6e3, + 0x475cd3205a3bdcde, 0x18a42105c31b7e88, 0x023e7414af663068, 0x15147108121967d7, + 0xe4a3dff1d7d6fef9, 0x01a8d1a588085737, 0x11b4c74eda62beef, 0xe587cc0d69a73346, + 0x1ff7327017aa2a6e, 0x594e29c42473d06b, 0xf6f31db1899b12d5, 0xc02ac5e47312d3ca, + 0xe70201e960cb78b8, 0x6f90ff3b6a65f108, 0x42747a7245e7fa84, 0xd1f507e43ab749b2, + 0x1c86d265f15750cd, 0x3996ce73dd832c1c, 0x8e7fba02983224bd, 0xba0dec7103255dd4, + 0x9e9cbd781628fc5b, 0xdae8645996edd6a5, 0xdebe0853b1a1d378, 0xa49229d24d014343, + 0x7be5b9ffda905e1c, 0xa3c95eaec244aa30, 0x0230bca8f4df0544, 0x4135c2bebfe148c6, + 0x166fc0cc438a3c72, 0x3762b59a8ae83efa, 0xe8928a4c89114750, 0x2a440b51a4945ee5, + 0x80cefd2b7d99ff83, 0xbb9879c6e61fd62a, 0x6e7c8f1a84265034, 0x164bb2de1bbeddc8, + 0xf3c12fe54d5c653b, 0x40b9e922ed9771e2, 0x551f5b0fbe7b1840, 0x25032aa7c4cb1811, + 0xaaed34074b164346, 0x8ffd96bbf9c9c81d, 0x70fc91eb5937085c, 0x7f795e2a5f915440, + 0x4543d9df5476d3cb, 0xf172d73e004fc90d, 0xdfd1c4febcc81238, 0xbc8dfb627fe558fc, +]; diff --git a/poseidon/src/digest.rs b/poseidon/src/digest.rs new file mode 100644 index 000000000..400682b1a --- /dev/null +++ b/poseidon/src/digest.rs @@ -0,0 +1,31 @@ +use crate::constants::DIGEST_WIDTH; +use goldilocks::SmallField; + +#[derive(Debug)] +pub struct Digest(pub [F; DIGEST_WIDTH]); + +impl TryFrom> for Digest { + type Error = String; + + fn try_from(values: Vec) -> Result { + if values.len() != DIGEST_WIDTH { + return Err(format!( + "can only create digest from {DIGEST_WIDTH} elements" + )); + } + + Ok(Digest(values.try_into().unwrap())) + } +} + +impl Digest 
{ + pub(crate) fn from_partial(inputs: &[F]) -> Self { + let mut elements = [F::ZERO; DIGEST_WIDTH]; + elements[0..inputs.len()].copy_from_slice(inputs); + Self(elements) + } + + pub(crate) fn elements(&self) -> &[F] { + &self.0 + } +} diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs new file mode 100644 index 000000000..b8c1a28fd --- /dev/null +++ b/poseidon/src/lib.rs @@ -0,0 +1,6 @@ +pub(crate) mod constants; +pub mod digest; +pub(crate) mod poseidon; +mod poseidon_goldilocks; +pub mod poseidon_hash; +pub mod poseidon_permutation; diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs new file mode 100644 index 000000000..ed5d76d14 --- /dev/null +++ b/poseidon/src/poseidon.rs @@ -0,0 +1,266 @@ +use crate::constants::{ + ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH, +}; +use goldilocks::SmallField; +use unroll::unroll_for_loops; + +pub trait Poseidon: AdaptedField { + // Total number of round constants required: width of the input + // times number of rounds. + const N_ROUND_CONSTANTS: usize = SPONGE_WIDTH * N_ROUNDS; + + // The MDS matrix we use is C + D, where C is the circulant matrix whose first + // row is given by `MDS_MATRIX_CIRC`, and D is the diagonal matrix whose + // diagonal is given by `MDS_MATRIX_DIAG`. + const MDS_MATRIX_CIRC: [u64; SPONGE_WIDTH]; + const MDS_MATRIX_DIAG: [u64; SPONGE_WIDTH]; + + // Precomputed constants for the fast Poseidon calculation. See + // the paper. 
+ const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; SPONGE_WIDTH]; + const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS]; + const FAST_PARTIAL_ROUND_VS: [[u64; SPONGE_WIDTH - 1]; N_PARTIAL_ROUNDS]; + const FAST_PARTIAL_ROUND_W_HATS: [[u64; SPONGE_WIDTH - 1]; N_PARTIAL_ROUNDS]; + const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; SPONGE_WIDTH - 1]; SPONGE_WIDTH - 1]; + + #[inline] + fn poseidon(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + let mut state = input; + let mut round_ctr = 0; + + Self::full_rounds(&mut state, &mut round_ctr); + Self::partial_rounds(&mut state, &mut round_ctr); + Self::full_rounds(&mut state, &mut round_ctr); + debug_assert_eq!(round_ctr, N_ROUNDS); + + state + } + + #[inline] + fn full_rounds(state: &mut [Self; SPONGE_WIDTH], round_ctr: &mut usize) { + for _ in 0..HALF_N_FULL_ROUNDS { + Self::constant_layer(state, *round_ctr); + Self::sbox_layer(state); + *state = Self::mds_layer(state); + *round_ctr += 1; + } + } + + #[inline] + fn partial_rounds(state: &mut [Self; SPONGE_WIDTH], round_ctr: &mut usize) { + Self::partial_first_constant_layer(state); + *state = Self::mds_partial_layer_init(state); + + for i in 0..N_PARTIAL_ROUNDS { + state[0] = Self::sbox_monomial(state[0]); + unsafe { + state[0] = state[0].add_canonical_u64(Self::FAST_PARTIAL_ROUND_CONSTANTS[i]); + } + *state = Self::mds_partial_layer_fast(state, i); + } + *round_ctr += N_PARTIAL_ROUNDS; + } + + #[inline(always)] + #[unroll_for_loops] + fn constant_layer(state: &mut [Self; SPONGE_WIDTH], round_ctr: usize) { + for i in 0..12 { + if i < SPONGE_WIDTH { + let round_constant = ALL_ROUND_CONSTANTS[i + SPONGE_WIDTH * round_ctr]; + unsafe { + state[i] = state[i].add_canonical_u64(round_constant); + } + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn sbox_layer(state: &mut [Self; SPONGE_WIDTH]) { + for i in 0..12 { + if i < SPONGE_WIDTH { + state[i] = Self::sbox_monomial(state[i]); + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn 
mds_layer(state_: &[Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + let mut result = [Self::ZERO; SPONGE_WIDTH]; + + let mut state = [0u64; SPONGE_WIDTH]; + for r in 0..SPONGE_WIDTH { + state[r] = state_[r].to_noncanonical_u64(); + } + + // This is a hacky way of fully unrolling the loop. + for r in 0..12 { + if r < SPONGE_WIDTH { + let sum = Self::mds_row_shf(r, &state); + let sum_lo = sum as u64; + let sum_hi = (sum >> 64) as u32; + result[r] = Self::from_noncanonical_u96(sum_lo, sum_hi); + } + } + + result + } + + #[inline(always)] + #[unroll_for_loops] + fn partial_first_constant_layer(state: &mut [Self; SPONGE_WIDTH]) { + for i in 0..12 { + if i < SPONGE_WIDTH { + state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); + } + } + } + + #[inline(always)] + #[unroll_for_loops] + fn mds_partial_layer_init(state: &[Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + let mut result = [Self::ZERO; SPONGE_WIDTH]; + + // Initial matrix has first row/column = [1, 0, ..., 0]; + + // c = 0 + result[0] = state[0]; + + for r in 1..12 { + if r < SPONGE_WIDTH { + for c in 1..12 { + if c < SPONGE_WIDTH { + // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in + // row-major order so that this dot product is cache + // friendly. + let t = Self::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1], + ); + result[c] += state[r] * t; + } + } + } + } + result + } + + #[inline(always)] + fn sbox_monomial(x: Self) -> Self { + // Observed a performance improvement by using x*x rather than x.square(). + // In Plonky2, where this function originates, operations might be over an algebraic extension field. + // Specialized square functions could leverage the field's structure for potential savings. + // Adding this note in case future generalizations or optimizations are considered. 
+ + // x |--> x^7 + let x2 = x * x; + let x4 = x2 * x2; + let x3 = x * x2; + x3 * x4 + } + + /// Computes s*A where s is the state row vector and A is the matrix + /// + /// [ M_00 | v ] + /// [ ------+--- ] + /// [ w_hat | Id ] + /// + /// M_00 is a scalar, v is 1x(t-1), w_hat is (t-1)x1 and Id is the + /// (t-1)x(t-1) identity matrix. + #[inline(always)] + #[unroll_for_loops] + fn mds_partial_layer_fast(state: &[Self; SPONGE_WIDTH], r: usize) -> [Self; SPONGE_WIDTH] { + // Set d = [M_00 | w^] dot [state] + + let mut d_sum = (0u128, 0u32); // u160 accumulator + for i in 1..12 { + if i < SPONGE_WIDTH { + let t = Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1] as u128; + let si = state[i].to_noncanonical_u64() as u128; + d_sum = add_u160_u128(d_sum, si * t); + } + } + let s0 = state[0].to_noncanonical_u64() as u128; + let mds0to0 = (Self::MDS_MATRIX_CIRC[0] + Self::MDS_MATRIX_DIAG[0]) as u128; + d_sum = add_u160_u128(d_sum, s0 * mds0to0); + let d = reduce_u160::(d_sum); + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let mut result = [Self::ZERO; SPONGE_WIDTH]; + result[0] = d; + for i in 1..12 { + if i < SPONGE_WIDTH { + let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = state[i].multiply_accumulate(state[0], t); + } + } + result + } + + #[inline(always)] + #[unroll_for_loops] + fn mds_row_shf(r: usize, v: &[u64; SPONGE_WIDTH]) -> u128 { + debug_assert!(r < SPONGE_WIDTH); + // The values of `MDS_MATRIX_CIRC` and `MDS_MATRIX_DIAG` are + // known to be small, so we can accumulate all the products for + // each row and reduce just once at the end (done by the + // caller). + + // NB: Unrolling this, calculating each term independently, and + // summing at the end, didn't improve performance for me. + let mut res = 0u128; + + // This is a hacky way of fully unrolling the loop. 
+ for i in 0..12 { + if i < SPONGE_WIDTH { + res += (v[(i + r) % SPONGE_WIDTH] as u128) * (Self::MDS_MATRIX_CIRC[i] as u128); + } + } + res += (v[r] as u128) * (Self::MDS_MATRIX_DIAG[r] as u128); + + res + } +} + +#[inline(always)] +const fn add_u160_u128((x_lo, x_hi): (u128, u32), y: u128) -> (u128, u32) { + let (res_lo, over) = x_lo.overflowing_add(y); + let res_hi = x_hi + (over as u32); + (res_lo, res_hi) +} + +#[inline(always)] +fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { + let n_lo_hi = (n_lo >> 64) as u64; + let n_lo_lo = n_lo as u64; + let reduced_hi: u64 = F::from_noncanonical_u96(n_lo_hi, n_hi).to_noncanonical_u64(); + let reduced128: u128 = ((reduced_hi as u128) << 64) + (n_lo_lo as u128); + F::from_noncanonical_u128(reduced128) +} + +pub trait AdaptedField: SmallField { + const ORDER: u64; + + fn from_noncanonical_u96(n_lo: u64, n_hi: u32) -> Self; + + fn from_noncanonical_u128(n: u128) -> Self; + + fn multiply_accumulate(&self, x: Self, y: Self) -> Self; + + /// Returns `n` as a field element. Assumes that `n` is already in canonical form, i.e. `n < Self::ORDER`. + // TODO: Should probably be unsafe. + fn from_canonical_u64(n: u64) -> Self { + debug_assert!(n < Self::ORDER); + Self::from(n) + } + + /// # Safety + /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must + /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this + /// precondition is not met. It is marked unsafe for this reason. + #[inline] + unsafe fn add_canonical_u64(&self, rhs: u64) -> Self { + // Default implementation. 
+ *self + Self::from_canonical_u64(rhs) + } +} diff --git a/poseidon/src/poseidon_goldilocks.rs b/poseidon/src/poseidon_goldilocks.rs new file mode 100644 index 000000000..d59eab72f --- /dev/null +++ b/poseidon/src/poseidon_goldilocks.rs @@ -0,0 +1,340 @@ +use crate::{ + constants::N_PARTIAL_ROUNDS, + poseidon::{AdaptedField, Poseidon}, +}; +use goldilocks::{Goldilocks, SmallField, EPSILON}; +#[cfg(target_arch = "x86_64")] +use std::hint::unreachable_unchecked; + +#[rustfmt::skip] +impl Poseidon for Goldilocks { + // The MDS matrix we use is C + D, where C is the circulant matrix whose first row is given by + // `MDS_MATRIX_CIRC`, and D is the diagonal matrix whose diagonal is given by `MDS_MATRIX_DIAG`. + // + // WARNING: If the MDS matrix is changed, then the following + // constants need to be updated accordingly: + // - FAST_PARTIAL_ROUND_CONSTANTS + // - FAST_PARTIAL_ROUND_VS + // - FAST_PARTIAL_ROUND_W_HATS + // - FAST_PARTIAL_ROUND_INITIAL_MATRIX + const MDS_MATRIX_CIRC: [u64; 12] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; + const MDS_MATRIX_DIAG: [u64; 12] = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; 12] = [ + 0x3cc3f892184df408, 0xe993fd841e7e97f1, 0xf2831d3575f0f3af, 0xd2500e0a350994ca, + 0xc5571f35d7288633, 0x91d89c5184109a02, 0xf37f925d04e5667b, 0x2d6e448371955a69, + 0x740ef19ce01398a1, 0x694d24c0752fdf45, 0x60936af96ee2f148, 0xc33448feadc78f0c, + ]; + + const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS] = [ + 0x74cb2e819ae421ab, 0xd2559d2370e7f663, 0x62bf78acf843d17c, 0xd5ab7b67e14d1fb4, + 0xb9fe2ae6e0969bdc, 0xe33fdf79f92a10e8, 0x0ea2bb4c2b25989b, 0xca9121fbf9d38f06, + 0xbdd9b0aa81f58fa4, 0x83079fa4ecf20d7e, 0x650b838edfcc4ad3, 0x77180c88583c76ac, + 0xaf8c20753143a180, 0xb8ccfe9989a39175, 0x954a1729f60cc9c5, 0xdeb5b550c4dca53b, + 0xf01bb0b00f77011e, 0xa1ebb404b676afd9, 0x860b6e1597a0173e, 0x308bb65a036acbce, + 0x1aca78f31c97c876, 0x0, + ]; + + const FAST_PARTIAL_ROUND_VS: [[u64; 12 - 1]; 
N_PARTIAL_ROUNDS] =[ + [0x94877900674181c3, 0xc6c67cc37a2a2bbd, 0xd667c2055387940f, 0x0ba63a63e94b5ff0, + 0x99460cc41b8f079f, 0x7ff02375ed524bb3, 0xea0870b47a8caf0e, 0xabcad82633b7bc9d, + 0x3b8d135261052241, 0xfb4515f5e5b0d539, 0x3ee8011c2b37f77c, ], + [0x0adef3740e71c726, 0xa37bf67c6f986559, 0xc6b16f7ed4fa1b00, 0x6a065da88d8bfc3c, + 0x4cabc0916844b46f, 0x407faac0f02e78d1, 0x07a786d9cf0852cf, 0x42433fb6949a629a, + 0x891682a147ce43b0, 0x26cfd58e7b003b55, 0x2bbf0ed7b657acb3, ], + [0x481ac7746b159c67, 0xe367de32f108e278, 0x73f260087ad28bec, 0x5cfc82216bc1bdca, + 0xcaccc870a2663a0e, 0xdb69cd7b4298c45d, 0x7bc9e0c57243e62d, 0x3cc51c5d368693ae, + 0x366b4e8cc068895b, 0x2bd18715cdabbca4, 0xa752061c4f33b8cf, ], + [0xb22d2432b72d5098, 0x9e18a487f44d2fe4, 0x4b39e14ce22abd3c, 0x9e77fde2eb315e0d, + 0xca5e0385fe67014d, 0x0c2cb99bf1b6bddb, 0x99ec1cd2a4460bfe, 0x8577a815a2ff843f, + 0x7d80a6b4fd6518a5, 0xeb6c67123eab62cb, 0x8f7851650eca21a5, ], + [0x11ba9a1b81718c2a, 0x9f7d798a3323410c, 0xa821855c8c1cf5e5, 0x535e8d6fac0031b2, + 0x404e7c751b634320, 0xa729353f6e55d354, 0x4db97d92e58bb831, 0xb53926c27897bf7d, + 0x965040d52fe115c5, 0x9565fa41ebd31fd7, 0xaae4438c877ea8f4, ], + [0x37f4e36af6073c6e, 0x4edc0918210800e9, 0xc44998e99eae4188, 0x9f4310d05d068338, + 0x9ec7fe4350680f29, 0xc5b2c1fdc0b50874, 0xa01920c5ef8b2ebe, 0x59fa6f8bd91d58ba, + 0x8bfc9eb89b515a82, 0xbe86a7a2555ae775, 0xcbb8bbaa3810babf, ], + [0x577f9a9e7ee3f9c2, 0x88c522b949ace7b1, 0x82f07007c8b72106, 0x8283d37c6675b50e, + 0x98b074d9bbac1123, 0x75c56fb7758317c1, 0xfed24e206052bc72, 0x26d7c3d1bc07dae5, + 0xf88c5e441e28dbb4, 0x4fe27f9f96615270, 0x514d4ba49c2b14fe, ], + [0xf02a3ac068ee110b, 0x0a3630dafb8ae2d7, 0xce0dc874eaf9b55c, 0x9a95f6cff5b55c7e, + 0x626d76abfed00c7b, 0xa0c1cf1251c204ad, 0xdaebd3006321052c, 0x3d4bd48b625a8065, + 0x7f1e584e071f6ed2, 0x720574f0501caed3, 0xe3260ba93d23540a, ], + [0xab1cbd41d8c1e335, 0x9322ed4c0bc2df01, 0x51c3c0983d4284e5, 0x94178e291145c231, + 0xfd0f1a973d6b2085, 0xd427ad96e2b39719, 
0x8a52437fecaac06b, 0xdc20ee4b8c4c9a80, + 0xa2c98e9549da2100, 0x1603fe12613db5b6, 0x0e174929433c5505, ], + [0x3d4eab2b8ef5f796, 0xcfff421583896e22, 0x4143cb32d39ac3d9, 0x22365051b78a5b65, + 0x6f7fd010d027c9b6, 0xd9dd36fba77522ab, 0xa44cf1cb33e37165, 0x3fc83d3038c86417, + 0xc4588d418e88d270, 0xce1320f10ab80fe2, 0xdb5eadbbec18de5d, ], + [0x1183dfce7c454afd, 0x21cea4aa3d3ed949, 0x0fce6f70303f2304, 0x19557d34b55551be, + 0x4c56f689afc5bbc9, 0xa1e920844334f944, 0xbad66d423d2ec861, 0xf318c785dc9e0479, + 0x99e2032e765ddd81, 0x400ccc9906d66f45, 0xe1197454db2e0dd9, ], + [0x84d1ecc4d53d2ff1, 0xd8af8b9ceb4e11b6, 0x335856bb527b52f4, 0xc756f17fb59be595, + 0xc0654e4ea5553a78, 0x9e9a46b61f2ea942, 0x14fc8b5b3b809127, 0xd7009f0f103be413, + 0x3e0ee7b7a9fb4601, 0xa74e888922085ed7, 0xe80a7cde3d4ac526, ], + [0x238aa6daa612186d, 0x9137a5c630bad4b4, 0xc7db3817870c5eda, 0x217e4f04e5718dc9, + 0xcae814e2817bd99d, 0xe3292e7ab770a8ba, 0x7bb36ef70b6b9482, 0x3c7835fb85bca2d3, + 0xfe2cdf8ee3c25e86, 0x61b3915ad7274b20, 0xeab75ca7c918e4ef, ], + [0xd6e15ffc055e154e, 0xec67881f381a32bf, 0xfbb1196092bf409c, 0xdc9d2e07830ba226, + 0x0698ef3245ff7988, 0x194fae2974f8b576, 0x7a5d9bea6ca4910e, 0x7aebfea95ccdd1c9, + 0xf9bd38a67d5f0e86, 0xfa65539de65492d8, 0xf0dfcbe7653ff787, ], + [0x0bd87ad390420258, 0x0ad8617bca9e33c8, 0x0c00ad377a1e2666, 0x0ac6fc58b3f0518f, + 0x0c0cc8a892cc4173, 0x0c210accb117bc21, 0x0b73630dbb46ca18, 0x0c8be4920cbd4a54, + 0x0bfe877a21be1690, 0x0ae790559b0ded81, 0x0bf50db2f8d6ce31, ], + [0x000cf29427ff7c58, 0x000bd9b3cf49eec8, 0x000d1dc8aa81fb26, 0x000bc792d5c394ef, + 0x000d2ae0b2266453, 0x000d413f12c496c1, 0x000c84128cfed618, 0x000db5ebd48fc0d4, + 0x000d1b77326dcb90, 0x000beb0ccc145421, 0x000d10e5b22b11d1, ], + [0x00000e24c99adad8, 0x00000cf389ed4bc8, 0x00000e580cbf6966, 0x00000cde5fd7e04f, + 0x00000e63628041b3, 0x00000e7e81a87361, 0x00000dabe78f6d98, 0x00000efb14cac554, + 0x00000e5574743b10, 0x00000d05709f42c1, 0x00000e4690c96af1, ], + [0x0000000f7157bc98, 0x0000000e3006d948, 
0x0000000fa65811e6, 0x0000000e0d127e2f, + 0x0000000fc18bfe53, 0x0000000fd002d901, 0x0000000eed6461d8, 0x0000001068562754, + 0x0000000fa0236f50, 0x0000000e3af13ee1, 0x0000000fa460f6d1, ], + [0x0000000011131738, 0x000000000f56d588, 0x0000000011050f86, 0x000000000f848f4f, + 0x00000000111527d3, 0x00000000114369a1, 0x00000000106f2f38, 0x0000000011e2ca94, + 0x00000000110a29f0, 0x000000000fa9f5c1, 0x0000000010f625d1, ], + [0x000000000011f718, 0x000000000010b6c8, 0x0000000000134a96, 0x000000000010cf7f, + 0x0000000000124d03, 0x000000000013f8a1, 0x0000000000117c58, 0x0000000000132c94, + 0x0000000000134fc0, 0x000000000010a091, 0x0000000000128961, ], + [0x0000000000001300, 0x0000000000001750, 0x000000000000114e, 0x000000000000131f, + 0x000000000000167b, 0x0000000000001371, 0x0000000000001230, 0x000000000000182c, + 0x0000000000001368, 0x0000000000000f31, 0x00000000000015c9, ], + [0x0000000000000014, 0x0000000000000022, 0x0000000000000012, 0x0000000000000027, + 0x000000000000000d, 0x000000000000000d, 0x000000000000001c, 0x0000000000000002, + 0x0000000000000010, 0x0000000000000029, 0x000000000000000f, ], + ]; + + const FAST_PARTIAL_ROUND_W_HATS: [[u64; 12 - 1]; N_PARTIAL_ROUNDS] = [ + [0x3d999c961b7c63b0, 0x814e82efcd172529, 0x2421e5d236704588, 0x887af7d4dd482328, + 0xa5e9c291f6119b27, 0xbdc52b2676a4b4aa, 0x64832009d29bcf57, 0x09c4155174a552cc, + 0x463f9ee03d290810, 0xc810936e64982542, 0x043b1c289f7bc3ac, ], + [0x673655aae8be5a8b, 0xd510fe714f39fa10, 0x2c68a099b51c9e73, 0xa667bfa9aa96999d, + 0x4d67e72f063e2108, 0xf84dde3e6acda179, 0x40f9cc8c08f80981, 0x5ead032050097142, + 0x6591b02092d671bb, 0x00e18c71963dd1b7, 0x8a21bcd24a14218a, ], + [0x202800f4addbdc87, 0xe4b5bdb1cc3504ff, 0xbe32b32a825596e7, 0x8e0f68c5dc223b9a, + 0x58022d9e1c256ce3, 0x584d29227aa073ac, 0x8b9352ad04bef9e7, 0xaead42a3f445ecbf, + 0x3c667a1d833a3cca, 0xda6f61838efa1ffe, 0xe8f749470bd7c446, ], + [0xc5b85bab9e5b3869, 0x45245258aec51cf7, 0x16e6b8e68b931830, 0xe2ae0f051418112c, + 0x0470e26a0093a65b, 
0x6bef71973a8146ed, 0x119265be51812daf, 0xb0be7356254bea2e, + 0x8584defff7589bd7, 0x3c5fe4aeb1fb52ba, 0x9e7cd88acf543a5e, ], + [0x179be4bba87f0a8c, 0xacf63d95d8887355, 0x6696670196b0074f, 0xd99ddf1fe75085f9, + 0xc2597881fef0283b, 0xcf48395ee6c54f14, 0x15226a8e4cd8d3b6, 0xc053297389af5d3b, + 0x2c08893f0d1580e2, 0x0ed3cbcff6fcc5ba, 0xc82f510ecf81f6d0, ], + [0x94b06183acb715cc, 0x500392ed0d431137, 0x861cc95ad5c86323, 0x05830a443f86c4ac, + 0x3b68225874a20a7c, 0x10b3309838e236fb, 0x9b77fc8bcd559e2c, 0xbdecf5e0cb9cb213, + 0x30276f1221ace5fa, 0x7935dd342764a144, 0xeac6db520bb03708, ], + [0x7186a80551025f8f, 0x622247557e9b5371, 0xc4cbe326d1ad9742, 0x55f1523ac6a23ea2, + 0xa13dfe77a3d52f53, 0xe30750b6301c0452, 0x08bd488070a3a32b, 0xcd800caef5b72ae3, + 0x83329c90f04233ce, 0xb5b99e6664a0a3ee, 0x6b0731849e200a7f, ], + [0xec3fabc192b01799, 0x382b38cee8ee5375, 0x3bfb6c3f0e616572, 0x514abd0cf6c7bc86, + 0x47521b1361dcc546, 0x178093843f863d14, 0xad1003c5d28918e7, 0x738450e42495bc81, + 0xaf947c59af5e4047, 0x4653fb0685084ef2, 0x057fde2062ae35bf, ], + [0xe376678d843ce55e, 0x66f3860d7514e7fc, 0x7817f3dfff8b4ffa, 0x3929624a9def725b, + 0x0126ca37f215a80a, 0xfce2f5d02762a303, 0x1bc927375febbad7, 0x85b481e5243f60bf, + 0x2d3c5f42a39c91a0, 0x0811719919351ae8, 0xf669de0add993131, ], + [0x7de38bae084da92d, 0x5b848442237e8a9b, 0xf6c705da84d57310, 0x31e6a4bdb6a49017, + 0x889489706e5c5c0f, 0x0e4a205459692a1b, 0xbac3fa75ee26f299, 0x5f5894f4057d755e, + 0xb0dc3ecd724bb076, 0x5e34d8554a6452ba, 0x04f78fd8c1fdcc5f, ], + [0x4dd19c38779512ea, 0xdb79ba02704620e9, 0x92a29a3675a5d2be, 0xd5177029fe495166, + 0xd32b3298a13330c1, 0x251c4a3eb2c5f8fd, 0xe1c48b26e0d98825, 0x3301d3362a4ffccb, + 0x09bb6c88de8cd178, 0xdc05b676564f538a, 0x60192d883e473fee, ], + [0x16b9774801ac44a0, 0x3cb8411e786d3c8e, 0xa86e9cf505072491, 0x0178928152e109ae, + 0x5317b905a6e1ab7b, 0xda20b3be7f53d59f, 0xcb97dedecebee9ad, 0x4bd545218c59f58d, + 0x77dc8d856c05a44a, 0x87948589e4f243fd, 0x7e5217af969952c2, ], + [0xbc58987d06a84e4d, 
0x0b5d420244c9cae3, 0xa3c4711b938c02c0, 0x3aace640a3e03990, + 0x865a0f3249aacd8a, 0x8d00b2a7dbed06c7, 0x6eacb905beb7e2f8, 0x045322b216ec3ec7, + 0xeb9de00d594828e6, 0x088c5f20df9e5c26, 0xf555f4112b19781f, ], + [0xa8cedbff1813d3a7, 0x50dcaee0fd27d164, 0xf1cb02417e23bd82, 0xfaf322786e2abe8b, + 0x937a4315beb5d9b6, 0x1b18992921a11d85, 0x7d66c4368b3c497b, 0x0e7946317a6b4e99, + 0xbe4430134182978b, 0x3771e82493ab262d, 0xa671690d8095ce82, ], + [0xb035585f6e929d9d, 0xba1579c7e219b954, 0xcb201cf846db4ba3, 0x287bf9177372cf45, + 0xa350e4f61147d0a6, 0xd5d0ecfb50bcff99, 0x2e166aa6c776ed21, 0xe1e66c991990e282, + 0x662b329b01e7bb38, 0x8aa674b36144d9a9, 0xcbabf78f97f95e65, ], + [0xeec24b15a06b53fe, 0xc8a7aa07c5633533, 0xefe9c6fa4311ad51, 0xb9173f13977109a1, + 0x69ce43c9cc94aedc, 0xecf623c9cd118815, 0x28625def198c33c7, 0xccfc5f7de5c3636a, + 0xf5e6c40f1621c299, 0xcec0e58c34cb64b1, 0xa868ea113387939f, ], + [0xd8dddbdc5ce4ef45, 0xacfc51de8131458c, 0x146bb3c0fe499ac0, 0x9e65309f15943903, + 0x80d0ad980773aa70, 0xf97817d4ddbf0607, 0xe4626620a75ba276, 0x0dfdc7fd6fc74f66, + 0xf464864ad6f2bb93, 0x02d55e52a5d44414, 0xdd8de62487c40925, ], + [0xc15acf44759545a3, 0xcbfdcf39869719d4, 0x33f62042e2f80225, 0x2599c5ead81d8fa3, + 0x0b306cb6c1d7c8d0, 0x658c80d3df3729b1, 0xe8d1b2b21b41429c, 0xa1b67f09d4b3ccb8, + 0x0e1adf8b84437180, 0x0d593a5e584af47b, 0xa023d94c56e151c7, ], + [0x49026cc3a4afc5a6, 0xe06dff00ab25b91b, 0x0ab38c561e8850ff, 0x92c3c8275e105eeb, + 0xb65256e546889bd0, 0x3c0468236ea142f6, 0xee61766b889e18f2, 0xa206f41b12c30415, + 0x02fe9d756c9f12d1, 0xe9633210630cbf12, 0x1ffea9fe85a0b0b1, ], + [0x81d1ae8cc50240f3, 0xf4c77a079a4607d7, 0xed446b2315e3efc1, 0x0b0a6b70915178c3, + 0xb11ff3e089f15d9a, 0x1d4dba0b7ae9cc18, 0x65d74e2f43b48d05, 0xa2df8c6b8ae0804a, + 0xa4e6f0a8c33348a6, 0xc0a26efc7be5669b, 0xa6b6582c547d0d60, ], + [0x84afc741f1c13213, 0x2f8f43734fc906f3, 0xde682d72da0a02d9, 0x0bb005236adb9ef2, + 0x5bdf35c10a8b5624, 0x0739a8a343950010, 0x52f515f44785cfbc, 0xcbaf4e5d82856c60, + 
0xac9ea09074e3e150, 0x8f0fa011a2035fb0, 0x1a37905d8450904a, ], + [0x3abeb80def61cc85, 0x9d19c9dd4eac4133, 0x075a652d9641a985, 0x9daf69ae1b67e667, + 0x364f71da77920a18, 0x50bd769f745c95b1, 0xf223d1180dbbf3fc, 0x2f885e584e04aa99, + 0xb69a0fa70aea684a, 0x09584acaa6e062a0, 0x0bc051640145b19b, ], + ]; + + // NB: This is in ROW-major order to support cache-friendly pre-multiplication. + const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 12 - 1]; 12 - 1] = [ + [0x80772dc2645b280b, 0xdc927721da922cf8, 0xc1978156516879ad, 0x90e80c591f48b603, + 0x3a2432625475e3ae, 0x00a2d4321cca94fe, 0x77736f524010c932, 0x904d3f2804a36c54, + 0xbf9b39e28a16f354, 0x3a1ded54a6cd058b, 0x42392870da5737cf, ], + [0xe796d293a47a64cb, 0xb124c33152a2421a, 0x0ee5dc0ce131268a, 0xa9032a52f930fae6, + 0x7e33ca8c814280de, 0xad11180f69a8c29e, 0xc75ac6d5b5a10ff3, 0xf0674a8dc5a387ec, + 0xb36d43120eaa5e2b, 0x6f232aab4b533a25, 0x3a1ded54a6cd058b, ], + [0xdcedab70f40718ba, 0x14a4a64da0b2668f, 0x4715b8e5ab34653b, 0x1e8916a99c93a88e, + 0xbba4b5d86b9a3b2c, 0xe76649f9bd5d5c2e, 0xaf8e2518a1ece54d, 0xdcda1344cdca873f, + 0xcd080204256088e5, 0xb36d43120eaa5e2b, 0xbf9b39e28a16f354, ], + [0xf4a437f2888ae909, 0xc537d44dc2875403, 0x7f68007619fd8ba9, 0xa4911db6a32612da, + 0x2f7e9aade3fdaec1, 0xe7ffd578da4ea43d, 0x43a608e7afa6b5c2, 0xca46546aa99e1575, + 0xdcda1344cdca873f, 0xf0674a8dc5a387ec, 0x904d3f2804a36c54, ], + [0xf97abba0dffb6c50, 0x5e40f0c9bb82aab5, 0x5996a80497e24a6b, 0x07084430a7307c9a, + 0xad2f570a5b8545aa, 0xab7f81fef4274770, 0xcb81f535cf98c9e9, 0x43a608e7afa6b5c2, + 0xaf8e2518a1ece54d, 0xc75ac6d5b5a10ff3, 0x77736f524010c932, ], + [0x7f8e41e0b0a6cdff, 0x4b1ba8d40afca97d, 0x623708f28fca70e8, 0xbf150dc4914d380f, + 0xc26a083554767106, 0x753b8b1126665c22, 0xab7f81fef4274770, 0xe7ffd578da4ea43d, + 0xe76649f9bd5d5c2e, 0xad11180f69a8c29e, 0x00a2d4321cca94fe, ], + [0x726af914971c1374, 0x1d7f8a2cce1a9d00, 0x18737784700c75cd, 0x7fb45d605dd82838, + 0x862361aeab0f9b6e, 0xc26a083554767106, 0xad2f570a5b8545aa, 0x2f7e9aade3fdaec1, + 
0xbba4b5d86b9a3b2c, 0x7e33ca8c814280de, 0x3a2432625475e3ae, ], + [0x64dd936da878404d, 0x4db9a2ead2bd7262, 0xbe2e19f6d07f1a83, 0x02290fe23c20351a, + 0x7fb45d605dd82838, 0xbf150dc4914d380f, 0x07084430a7307c9a, 0xa4911db6a32612da, + 0x1e8916a99c93a88e, 0xa9032a52f930fae6, 0x90e80c591f48b603, ], + [0x85418a9fef8a9890, 0xd8a2eb7ef5e707ad, 0xbfe85ababed2d882, 0xbe2e19f6d07f1a83, + 0x18737784700c75cd, 0x623708f28fca70e8, 0x5996a80497e24a6b, 0x7f68007619fd8ba9, + 0x4715b8e5ab34653b, 0x0ee5dc0ce131268a, 0xc1978156516879ad, ], + [0x156048ee7a738154, 0x91f7562377e81df5, 0xd8a2eb7ef5e707ad, 0x4db9a2ead2bd7262, + 0x1d7f8a2cce1a9d00, 0x4b1ba8d40afca97d, 0x5e40f0c9bb82aab5, 0xc537d44dc2875403, + 0x14a4a64da0b2668f, 0xb124c33152a2421a, 0xdc927721da922cf8, ], + [0xd841e8ef9dde8ba0, 0x156048ee7a738154, 0x85418a9fef8a9890, 0x64dd936da878404d, + 0x726af914971c1374, 0x7f8e41e0b0a6cdff, 0xf97abba0dffb6c50, 0xf4a437f2888ae909, + 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b, ], + ]; + + +} + +impl AdaptedField for Goldilocks { + const ORDER: u64 = Goldilocks::MODULUS_U64; + + fn from_noncanonical_u96(n_lo: u64, n_hi: u32) -> Self { + reduce96((n_lo, n_hi)) + } + + fn from_noncanonical_u128(n: u128) -> Self { + reduce128(n) + } + + fn multiply_accumulate(&self, x: Self, y: Self) -> Self { + // u64 + u64 * u64 cannot overflow. + reduce128((self.0 as u128) + (x.0 as u128) * (y.0 as u128)) + } +} + +/// Fast addition modulo ORDER for x86-64. +/// This function is marked unsafe for the following reasons: +/// - It is only correct if x + y < 2**64 + ORDER = 0x1ffffffff00000001. +/// - It is only faster in some circumstances. In particular, on x86 it overwrites both inputs in +/// the registers, so its use is not recommended when either input will be used again. +#[inline(always)] +#[cfg(target_arch = "x86_64")] +unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + let res_wrapped: u64; + let adjustment: u64; + core::arch::asm!( + "add {0}, {1}", + // Trick. 
The carry flag is set iff the addition overflowed. + // sbb x, y does x := x - y - CF. In our case, x and y are both {1:e}, so it simply does + // {1:e} := 0xffffffff on overflow and {1:e} := 0 otherwise. {1:e} is the low 32 bits of + // {1}; the high 32-bits are zeroed on write. In the end, we end up with 0xffffffff in {1} + // on overflow; this happens to be EPSILON. + // Note that the CPU does not realize that the result of sbb x, x does not actually depend + // on x. We must write the result to a register that we know to be ready. We have a + // dependency on {1} anyway, so let's use it. + "sbb {1:e}, {1:e}", + inlateout(reg) x => res_wrapped, + inlateout(reg) y => adjustment, + options(pure, nomem, nostack), + ); + assume(x != 0 || (res_wrapped == y && adjustment == 0)); + assume(y != 0 || (res_wrapped == x && adjustment == 0)); + // Add EPSILON == subtract ORDER. + // Cannot overflow unless the assumption that x + y < 2**64 + ORDER is incorrect. + res_wrapped + adjustment +} + +#[inline(always)] +#[cfg(not(target_arch = "x86_64"))] +const unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { + let (res_wrapped, carry) = x.overflowing_add(y); + // Below cannot overflow unless the assumption that x + y < 2**64 + ORDER is incorrect. + res_wrapped + EPSILON * (carry as u64) +} + +/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the +/// field order and `2^64`. +#[inline] +fn reduce96((x_lo, x_hi): (u64, u32)) -> Goldilocks { + let t1 = x_hi as u64 * EPSILON; + let t2 = unsafe { add_no_canonicalize_trashing_input(x_lo, t1) }; + Goldilocks(t2) +} + +/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the +/// field order and `2^64`.
+#[inline] +fn reduce128(x: u128) -> Goldilocks { + let (x_lo, x_hi) = split(x); // This is a no-op + let x_hi_hi = x_hi >> 32; + let x_hi_lo = x_hi & EPSILON; + + let (mut t0, borrow) = x_lo.overflowing_sub(x_hi_hi); + if borrow { + branch_hint(); // A borrow is exceedingly rare. It is faster to branch. + t0 -= EPSILON; // Cannot underflow. + } + let t1 = x_hi_lo * EPSILON; + let t2 = unsafe { add_no_canonicalize_trashing_input(t0, t1) }; + Goldilocks(t2) +} + +#[inline] +const fn split(x: u128) -> (u64, u64) { + (x as u64, (x >> 64) as u64) +} + +#[inline(always)] +#[cfg(target_arch = "x86_64")] +pub fn assume(p: bool) { + debug_assert!(p); + if !p { + unsafe { + unreachable_unchecked(); + } + } +} + +/// Try to force Rust to emit a branch. Example: +/// if x > 2 { +/// y = foo(); +/// branch_hint(); +/// } else { +/// y = bar(); +/// } +/// This function has no semantics. It is a hint only. +#[inline(always)] +pub fn branch_hint() { + // NOTE: These are the currently supported assembly architectures. See the + // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for + // the most up-to-date list. 
+ #[cfg(any( + target_arch = "aarch64", + target_arch = "arm", + target_arch = "riscv32", + target_arch = "riscv64", + target_arch = "x86", + target_arch = "x86_64", + ))] + unsafe { + core::arch::asm!("", options(nomem, nostack, preserves_flags)); + } +} diff --git a/poseidon/src/poseidon_hash.rs b/poseidon/src/poseidon_hash.rs new file mode 100644 index 000000000..3a09685ae --- /dev/null +++ b/poseidon/src/poseidon_hash.rs @@ -0,0 +1,156 @@ +use crate::{ + constants::{DIGEST_WIDTH, SPONGE_RATE, SPONGE_WIDTH}, + digest::Digest, + poseidon::{AdaptedField, Poseidon}, + poseidon_permutation::PoseidonPermutation, +}; + +pub struct PoseidonHash; + +impl PoseidonHash { + pub fn two_to_one( + left: &Digest, + right: &Digest, + ) -> Digest { + compress(left, right) + } + + pub fn hash_or_noop(inputs: &[F]) -> Digest { + if inputs.len() <= DIGEST_WIDTH { + Digest::from_partial(inputs) + } else { + hash_n_to_hash_no_pad(inputs) + } + } +} + +pub fn hash_n_to_m_no_pad(inputs: &[F], num_outputs: usize) -> Vec { + let mut perm = PoseidonPermutation::new(core::iter::repeat(F::ZERO)); + + // Absorb all input chunks. + for input_chunk in inputs.chunks(SPONGE_RATE) { + // Overwrite the first r elements with the inputs. This differs from a standard sponge, + // where we would xor or add in the inputs. This is a well-known variant, though, + // sometimes called "overwrite mode". 
+ perm.set_from_slice(input_chunk, 0); + perm.permute(); + } + + // Squeeze until we have the desired number of outputs + let mut outputs = Vec::with_capacity(num_outputs); + loop { + for &item in perm.squeeze() { + outputs.push(item); + if outputs.len() == num_outputs { + return outputs; + } + } + perm.permute(); + } +} + +pub fn hash_n_to_hash_no_pad(inputs: &[F]) -> Digest { + hash_n_to_m_no_pad(inputs, DIGEST_WIDTH).try_into().unwrap() +} + +pub fn compress(x: &Digest, y: &Digest) -> Digest { + debug_assert!(SPONGE_RATE >= DIGEST_WIDTH); + debug_assert!(SPONGE_WIDTH >= 2 * DIGEST_WIDTH); + debug_assert_eq!(x.elements().len(), DIGEST_WIDTH); + debug_assert_eq!(y.elements().len(), DIGEST_WIDTH); + + let mut perm = PoseidonPermutation::new(core::iter::repeat(F::ZERO)); + perm.set_from_slice(x.elements(), 0); + perm.set_from_slice(y.elements(), DIGEST_WIDTH); + + perm.permute(); + + Digest(perm.squeeze()[..DIGEST_WIDTH].try_into().unwrap()) +} + +#[cfg(test)] +mod tests { + use crate::{digest::Digest, poseidon_hash::PoseidonHash}; + use goldilocks::Goldilocks; + use plonky2::{ + field::{ + goldilocks_field::GoldilocksField, + types::{PrimeField64, Sample}, + }, + hash::{hash_types::HashOut, poseidon::PoseidonHash as PlonkyPoseidonHash}, + plonk::config::{GenericHashOut, Hasher}, + }; + use rand::{thread_rng, Rng}; + + type PlonkyFieldElements = Vec; + type CenoFieldElements = Vec; + + const N_ITERATIONS: usize = 100; + + fn ceno_goldy_from_plonky_goldy(values: &[GoldilocksField]) -> Vec { + values + .iter() + .map(|value| Goldilocks(value.to_canonical_u64())) + .collect() + } + + fn test_vector_pair(n: usize) -> (PlonkyFieldElements, CenoFieldElements) { + let plonky_elems = GoldilocksField::rand_vec(n); + let ceno_elems = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); + (plonky_elems, ceno_elems) + } + + fn random_hash_pair() -> (HashOut, Digest) { + let plonky_random_hash = HashOut::::rand(); + let ceno_equivalent_hash = Digest( + 
ceno_goldy_from_plonky_goldy(plonky_random_hash.elements.as_slice()) + .try_into() + .unwrap(), + ); + (plonky_random_hash, ceno_equivalent_hash) + } + + fn compare_hash_output( + plonky_hash: HashOut, + ceno_hash: Digest, + ) -> bool { + let plonky_elems = plonky_hash.to_vec(); + let plonky_in_ceno_field = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); + plonky_in_ceno_field == ceno_hash.elements() + } + + #[test] + fn compare_hash() { + let mut rng = thread_rng(); + for _ in 0..N_ITERATIONS { + let n = rng.gen_range(5..=100); + let (plonky_elems, ceno_elems) = test_vector_pair(n); + let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); + let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); + assert!(compare_hash_output(plonky_out, ceno_out)); + } + } + + #[test] + fn compare_noop() { + let mut rng = thread_rng(); + for _ in 0..N_ITERATIONS { + let n = rng.gen_range(0..=4); + let (plonky_elems, ceno_elems) = test_vector_pair(n); + let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); + let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); + assert!(compare_hash_output(plonky_out, ceno_out)); + } + } + + #[test] + fn compare_two_to_one() { + for _ in 0..N_ITERATIONS { + let (plonky_hash_a, ceno_hash_a) = random_hash_pair(); + let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); + let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); + let ceno_combined = PoseidonHash::two_to_one(&ceno_hash_a, &ceno_hash_b); + assert!(compare_hash_output(plonky_combined, ceno_combined)); + } + } +} diff --git a/poseidon/src/poseidon_permutation.rs b/poseidon/src/poseidon_permutation.rs new file mode 100644 index 000000000..423f8148d --- /dev/null +++ b/poseidon/src/poseidon_permutation.rs @@ -0,0 +1,52 @@ +use crate::{ + constants::{SPONGE_RATE, SPONGE_WIDTH}, + poseidon::Poseidon, +}; + +pub struct PoseidonPermutation { + state: [T; SPONGE_WIDTH], +} + +impl 
PoseidonPermutation { + /// Initialises internal state with values from `iter` until + /// `iter` is exhausted or `SPONGE_WIDTH` values have been + /// received; remaining state (if any) initialised with + /// `T::default()`. To initialise remaining elements with a + /// different value, instead of your original `iter` pass + /// `iter.chain(core::iter::repeat(F::from_canonical_u64(12345)))` + /// or similar. + pub fn new>(elts: I) -> Self { + let mut perm = Self { + state: [T::default(); SPONGE_WIDTH], + }; + perm.set_from_iter(elts, 0); + perm + } + + /// Copy `elts` into state elements + /// `start_idx..start_idx + elts.len()`. Panics if + /// `start_idx + elts.len() > SPONGE_WIDTH`. + pub(crate) fn set_from_slice(&mut self, elts: &[T], start_idx: usize) { + let begin = start_idx; + let end = start_idx + elts.len(); + self.state[begin..end].copy_from_slice(elts) + } + + /// Set state element `start_idx + i` to the `i`-th item of `elts` until `elts` is + /// exhausted or the state is full. Panics if `start_idx > SPONGE_WIDTH`.
+ fn set_from_iter>(&mut self, elts: I, start_idx: usize) { + for (s, e) in self.state[start_idx..].iter_mut().zip(elts) { + *s = e; + } + } + + /// Apply permutation to internal state + pub fn permute(&mut self) { + self.state = T::poseidon(self.state); + } + + /// Return a slice of `RATE` elements + pub fn squeeze(&self) -> &[T] { + &self.state[..SPONGE_RATE] + } +} diff --git a/rust-toolchain b/rust-toolchain index 07ade694b..7a5fe266d 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly \ No newline at end of file +nightly-2024-05-02 \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml index c46be5f21..d37b10e82 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,5 +1,6 @@ -edition = "2021" +edition = "2024" version = "Two" +style_edition = "2024" wrap_comments = false comment_width = 300 imports_granularity = "Crate" diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index 6a48171e1..2397a9cb8 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -16,8 +16,6 @@ pub struct IOPProof { impl IOPProof { #[allow(dead_code)] pub fn extract_sum(&self) -> E { - - self.proofs[0].evaluations[0] + self.proofs[0].evaluations[1] } }