Skip to content

Commit

Permalink
Refactor row major matrix (#624)
Browse files Browse the repository at this point in the history
_Written December 10th, most content and conversation predates this
description._

Fixes #600.
Makes several changes to the way `RowMajorMatrix` works:

1. It no longer uses `MaybeUninit`. Rather, its contents are initialized
to `T::default()` using parallel iterators.
2. Padding is no longer allocated in the constructor. It is also not the
concern of any particular user of the matrix to ensure correct padding.
Rather, users only pass the `InstancePaddingStrategy` argument which
describes the type of padding that they want. This padding is then
performed at the very latest stage (in the call to `self.into_mles()`).
3. `self.into_mles()` does more than before. `de_interleaving` is
removed because the indirection through a bi-dimensional matrix is
unnecessary. Notably, `self.into_mles()` parallelizes the work
differently than `de_interleaving` used to.
  • Loading branch information
mcalancea authored Dec 12, 2024
1 parent 0cd9258 commit 979ee30
Show file tree
Hide file tree
Showing 52 changed files with 306 additions and 320 deletions.
9 changes: 2 additions & 7 deletions ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::{
cmp::max,
fmt::Display,
iter::{Product, Sum},
mem::MaybeUninit,
ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Shl, ShlAssign, Sub, SubAssign},
};

Expand Down Expand Up @@ -756,12 +755,8 @@ impl WitIn {
)
}

pub fn assign<E: ExtensionField>(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
value: E::BaseField,
) {
instance[self.id as usize] = MaybeUninit::new(value);
pub fn assign<E: ExtensionField>(&self, instance: &mut [E::BaseField], value: E::BaseField) {
instance[self.id as usize] = value;
}
}

Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/gadgets/div.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Display, mem::MaybeUninit};
use std::fmt::Display;

use ff_ext::ExtensionField;

Expand Down Expand Up @@ -53,7 +53,7 @@ impl<E: ExtensionField> DivConfig<E> {

pub fn assign_instance<'a>(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lkm: &mut LkMultiplicity,
divisor: &Value<'a, u32>,
quotient: &Value<'a, u32>,
Expand Down
18 changes: 9 additions & 9 deletions ceno_zkvm/src/gadgets/is_lt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Display, mem::MaybeUninit};
use std::fmt::Display;

use ceno_emul::{SWord, Word};
use ff_ext::ExtensionField;
Expand Down Expand Up @@ -52,7 +52,7 @@ impl AssertLtConfig {

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
lkm: &mut LkMultiplicity,
lhs: u64,
rhs: u64,
Expand Down Expand Up @@ -106,7 +106,7 @@ impl IsLtConfig {

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
lkm: &mut LkMultiplicity,
lhs: u64,
rhs: u64,
Expand All @@ -118,7 +118,7 @@ impl IsLtConfig {

pub fn assign_instance_signed<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
lkm: &mut LkMultiplicity,
lhs: SWord,
rhs: SWord,
Expand Down Expand Up @@ -184,7 +184,7 @@ impl InnerLtConfig {

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
lkm: &mut LkMultiplicity,
lhs: u64,
rhs: u64,
Expand All @@ -202,7 +202,7 @@ impl InnerLtConfig {
// TODO: refactor with the above function
pub fn assign_instance_signed<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
lkm: &mut LkMultiplicity,
lhs: SWord,
rhs: SWord,
Expand Down Expand Up @@ -256,7 +256,7 @@ impl<E: ExtensionField> AssertSignedLtConfig<E> {

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lkm: &mut LkMultiplicity,
lhs: SWord,
rhs: SWord,
Expand Down Expand Up @@ -299,7 +299,7 @@ impl<E: ExtensionField> SignedLtConfig<E> {

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lkm: &mut LkMultiplicity,
lhs: SWord,
rhs: SWord,
Expand Down Expand Up @@ -351,7 +351,7 @@ impl<E: ExtensionField> InnerSignedLtConfig<E> {

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lkm: &mut LkMultiplicity,
lhs: SWord,
rhs: SWord,
Expand Down
6 changes: 2 additions & 4 deletions ceno_zkvm/src/gadgets/is_zero.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::mem::MaybeUninit;

use ff_ext::ExtensionField;
use goldilocks::SmallField;

Expand Down Expand Up @@ -64,7 +62,7 @@ impl IsZeroConfig {

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
x: F,
) -> Result<(), ZKVMError> {
let (is_zero, inverse) = if x.is_zero_vartime() {
Expand Down Expand Up @@ -117,7 +115,7 @@ impl IsEqualConfig {

pub fn assign_instance<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
a: F,
b: F,
) -> Result<(), ZKVMError> {
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/gadgets/signed_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
witness::LkMultiplicity,
};
use ff_ext::ExtensionField;
use std::{marker::PhantomData, mem::MaybeUninit};
use std::marker::PhantomData;

#[derive(Debug)]
pub struct SignedExtendConfig<E> {
Expand Down Expand Up @@ -84,7 +84,7 @@ impl<E: ExtensionField> SignedExtendConfig<E> {

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lk_multiplicity: &mut LkMultiplicity,
val: u64,
) -> Result<(), ZKVMError> {
Expand Down
49 changes: 13 additions & 36 deletions ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
use std::mem::MaybeUninit;

use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
use multilinear_extensions::util::max_usable_threads;
use rayon::{
iter::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSlice,
};
use std::sync::Arc;

use crate::{
circuit_builder::CircuitBuilder,
error::ZKVMError,
scheme::constants::MIN_PAR_SIZE,
witness::{LkMultiplicity, RowMajorMatrix},
};
use ff::Field;

pub mod riscv;

#[derive(Clone)]
pub enum InstancePaddingStrategy {
Zero,
// Pads with default values of underlying type
// Usually zero, but check carefully
Default,
// Pads by repeating last row
RepeatLast,
// Custom strategy consists of a closure
// `pad(i, j) = padding value for cell at row i, column j`
// pad should be able to cross thread boundaries
Custom(Arc<dyn Fn(u64, u64) -> u64 + Send + Sync>),
}

pub trait Instruction<E: ExtensionField> {
Expand All @@ -38,7 +43,7 @@ pub trait Instruction<E: ExtensionField> {
// assign single instance giving step from trace
fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError>;
Expand All @@ -56,7 +61,8 @@ pub trait Instruction<E: ExtensionField> {
}
.max(1);
let lk_multiplicity = LkMultiplicity::default();
let mut raw_witin = RowMajorMatrix::<E::BaseField>::new(steps.len(), num_witin);
let mut raw_witin =
RowMajorMatrix::<E::BaseField>::new(steps.len(), num_witin, Self::padding_strategy());
let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch);

raw_witin_iter
Expand All @@ -73,35 +79,6 @@ pub trait Instruction<E: ExtensionField> {
})
.collect::<Result<(), ZKVMError>>()?;

let num_padding_instances = raw_witin.num_padding_instances();
if num_padding_instances > 0 {
// Fill the padding based on strategy

let padding_instance = match Self::padding_strategy() {
InstancePaddingStrategy::Zero => {
vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]
}
InstancePaddingStrategy::RepeatLast if steps.is_empty() => {
tracing::debug!("No {} steps to repeat, using zero padding", Self::name());
vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]
}
InstancePaddingStrategy::RepeatLast => raw_witin[steps.len() - 1].to_vec(),
};

let num_padding_instance_per_batch = if num_padding_instances > 256 {
num_padding_instances.div_ceil(nthreads)
} else {
num_padding_instances
};
raw_witin
.par_batch_iter_padding_mut(num_padding_instance_per_batch)
.with_min_len(MIN_PAR_SIZE)
.for_each(|row| {
row.chunks_mut(num_witin)
.for_each(|instance| instance.copy_from_slice(padding_instance.as_slice()));
});
}

Ok((raw_witin, lk_multiplicity))
}
}
3 changes: 1 addition & 2 deletions ceno_zkvm/src/instructions/riscv/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, uint::Value,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

/// This config handles R-Instructions that represent registers values as 2 * u16.
#[derive(Debug)]
Expand Down Expand Up @@ -88,7 +87,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [<E as ExtensionField>::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/arith_imm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{marker::PhantomData, mem::MaybeUninit};
use std::marker::PhantomData;

use ceno_emul::StepRecord;
use ff_ext::ExtensionField;
Expand Down Expand Up @@ -57,7 +57,7 @@ impl<E: ExtensionField> Instruction<E> for AddiInstruction<E> {

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [<E as ExtensionField>::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
Expand Down
3 changes: 1 addition & 2 deletions ceno_zkvm/src/instructions/riscv/b_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::{
utils::i64_to_base,
witness::LkMultiplicity,
};
use core::mem::MaybeUninit;

// Opcode: 1100011
// Funct3:
Expand Down Expand Up @@ -88,7 +87,7 @@ impl<E: ExtensionField> BInstructionConfig<E> {

pub fn assign_instance(
&self,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [<E as ExtensionField>::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{marker::PhantomData, mem::MaybeUninit};
use std::marker::PhantomData;

use ceno_emul::{InsnKind, StepRecord};
use ff_ext::ExtensionField;
Expand Down Expand Up @@ -72,7 +72,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BeqCircuit<E, I> {

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [<E as ExtensionField>::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &StepRecord,
) -> Result<(), ZKVMError> {
Expand Down
4 changes: 2 additions & 2 deletions ceno_zkvm/src/instructions/riscv/branch/blt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{marker::PhantomData, mem::MaybeUninit};
use std::marker::PhantomData;

use ff_ext::ExtensionField;

Expand Down Expand Up @@ -66,7 +66,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BltCircuit<E, I> {

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/branch/bltu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BltuCircuit<E, I>

fn assign_instance(
config: &Self::InstructionConfig,
instance: &mut [std::mem::MaybeUninit<E::BaseField>],
instance: &mut [E::BaseField],
lk_multiplicity: &mut LkMultiplicity,
step: &ceno_emul::StepRecord,
) -> Result<(), ZKVMError> {
Expand Down
8 changes: 3 additions & 5 deletions ceno_zkvm/src/instructions/riscv/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::mem::MaybeUninit;

use crate::{expression::WitIn, set_val, utils::i64_to_base, witness::LkMultiplicity};
use goldilocks::SmallField;
use itertools::Itertools;
Expand All @@ -25,7 +23,7 @@ pub struct MsbInput<'a> {
impl MsbInput<'_> {
pub fn assign<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
config: &MsbConfig,
lk_multiplicity: &mut LkMultiplicity,
) -> (u8, u8) {
Expand Down Expand Up @@ -61,7 +59,7 @@ pub struct UIntLtuInput<'a> {
impl UIntLtuInput<'_> {
pub fn assign<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
config: &UIntLtuConfig,
lk_multiplicity: &mut LkMultiplicity,
) -> bool {
Expand Down Expand Up @@ -138,7 +136,7 @@ pub struct UIntLtInput<'a> {
impl UIntLtInput<'_> {
pub fn assign<F: SmallField>(
&self,
instance: &mut [MaybeUninit<F>],
instance: &mut [F],
config: &UIntLtConfig,
lk_multiplicity: &mut LkMultiplicity,
) -> bool {
Expand Down
Loading

0 comments on commit 979ee30

Please sign in to comment.