Skip to content

Commit

Permalink
feat: verify shard transitions + fixes (#482)
Browse files Browse the repository at this point in the history
  • Loading branch information
ctian1 authored Apr 17, 2024
1 parent 04e0a27 commit f81fac0
Show file tree
Hide file tree
Showing 23 changed files with 302 additions and 112 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 12 additions & 2 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use p3_air::BaseAir;
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;

use crate::stark::MachineRecord;
use crate::{runtime::Program, stark::MachineRecord};

pub use sp1_derive::MachineAir;

Expand All @@ -11,7 +11,7 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
/// The execution record containing events for producing the air trace.
type Record: MachineRecord;

type Program: Send + Sync;
type Program: MachineProgram<F>;

/// A unique identifier for this AIR as part of a machine.
fn name(&self) -> String;
Expand Down Expand Up @@ -41,3 +41,13 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
None
}
}

pub trait MachineProgram<F>: Send + Sync {
fn pc_start(&self) -> F;
}

impl<F: Field> MachineProgram<F> for Program {
fn pc_start(&self) -> F {
F::from_canonical_u32(self.pc_start)
}
}
21 changes: 2 additions & 19 deletions core/src/cpu/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ where
&local.op_b_access,
AB::Expr::one() - local.selectors.imm_b,
);
builder
.when_not(local.selectors.imm_b)
.assert_word_eq(local.op_b_val(), *local.op_b_access.prev_value());

builder.eval_memory_access(
local.shard,
Expand All @@ -84,9 +81,6 @@ where
&local.op_c_access,
AB::Expr::one() - local.selectors.imm_c,
);
builder
.when_not(local.selectors.imm_c)
.assert_word_eq(local.op_c_val(), *local.op_c_access.prev_value());

// Write the `a` or the result to the first register described in the instruction unless
// we are performing a branch or a store.
Expand Down Expand Up @@ -407,10 +401,10 @@ impl CpuChip {
.when(is_ecall_instruction.clone() * is_enter_unconstrained)
.assert_word_eq(local.op_a_val(), zero_word);

// When the syscall is not one of ENTER_UNCONSTRAINED, HINT_LEN, or HALT, op_a shouldn't change.
// When the syscall is not one of ENTER_UNCONSTRAINED or HINT_LEN, op_a shouldn't change.
builder
.when(is_ecall_instruction.clone())
.when_not(is_enter_unconstrained + is_hint_len + is_halt)
.when_not(is_enter_unconstrained + is_hint_len)
.assert_word_eq(local.op_a_val(), local.op_a_access.prev_value);

(
Expand Down Expand Up @@ -574,17 +568,6 @@ impl CpuChip {
builder.index_word_array(&commit_digest, &ecall_columns.index_bitmap);

let digest_word = local.op_c_access.prev_value();
// Verify b and c do not change during commit syscall.
builder
.when(
local.selectors.is_ecall * (is_commit.clone() + is_commit_deferred_proofs.clone()),
)
.assert_word_eq(*local.op_b_access.value(), *local.op_b_access.prev_value());
builder
.when(
local.selectors.is_ecall * (is_commit.clone() + is_commit_deferred_proofs.clone()),
)
.assert_word_eq(*local.op_c_access.value(), *local.op_c_access.prev_value());

// Verify the public_values_digest_word.
builder
Expand Down
4 changes: 2 additions & 2 deletions core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,6 @@ impl CpuChip {
let syscall_id = cols.op_a_access.prev_value[0];
// let send_to_table = cols.op_a_access.prev_value[1];
// let num_cycles = cols.op_a_access.prev_value[2];
// let is_halt = cols.op_a_access.prev_value[3];

// Populate `is_enter_unconstrained`.
ecall_cols
Expand Down Expand Up @@ -621,7 +620,7 @@ mod tests {
use super::*;

use crate::runtime::{tests::simple_program, Instruction, Runtime};
use crate::utils::run_test;
use crate::utils::{run_test, setup_logger};

#[test]
fn generate_trace() {
Expand Down Expand Up @@ -671,6 +670,7 @@ mod tests {

#[test]
fn prove_trace() {
setup_logger();
let program = simple_program();
run_test(program).unwrap();
}
Expand Down
6 changes: 6 additions & 0 deletions core/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ impl SP1PublicValues {
}
}

impl AsRef<[u8]> for SP1PublicValues {
fn as_ref(&self) -> &[u8] {
&self.buffer.data
}
}

pub mod proof_serde {
use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};

Expand Down
85 changes: 59 additions & 26 deletions core/src/memory/program.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use p3_air::{Air, BaseAir, PairBuilder};
use p3_air::{Air, AirBuilder, BaseAir, PairBuilder};
use p3_field::AbstractField;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use std::collections::BTreeMap;

use sp1_derive::AlignedBorrow;

use crate::air::{AirInteraction, SP1AirBuilder};
use crate::air::{AirInteraction, PublicValues, SP1AirBuilder};
use crate::air::{MachineAir, Word};
use crate::operations::IsZeroOperation;
use crate::runtime::{ExecutionRecord, Program};
use crate::utils::pad_to_power_of_two;

Expand All @@ -24,16 +24,22 @@ pub const NUM_MEMORY_PROGRAM_MULT_COLS: usize = size_of::<MemoryProgramMultCols<
pub struct MemoryProgramPreprocessedCols<T> {
pub addr: T,
pub value: Word<T>,
pub is_real: T,
}

/// The column layout for the chip.
/// Multiplicity columns.
#[derive(AlignedBorrow, Clone, Copy, Default)]
#[repr(C)]
pub struct MemoryProgramMultCols<T> {
pub used: T,
/// The multiplicity of the event, must be 1 in the first shard and 0 otherwise.
pub multiplicity: T,
/// Columns to see if current shard is 1.
pub is_first_shard: IsZeroOperation<T>,
}

/// Chip that initializes memory that is provided from the program.
/// Chip that initializes memory that is provided from the program. The table is preprocessed and
/// receives each row in the first shard. This prevents any of these addresses from being
/// overwritten through the normal MemoryInit.
#[derive(Default)]
pub struct MemoryProgramChip;

Expand All @@ -58,13 +64,16 @@ impl<F: PrimeField> MachineAir<F> for MemoryProgramChip {

fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
let program_memory = program.memory_image.clone();
// Note that BTreeMap is guaranteed to be sorted by key. This makes the row order
// deterministic.
let rows = program_memory
.into_iter()
.map(|(addr, word)| {
let mut row = [F::zero(); NUM_MEMORY_PROGRAM_PREPROCESSED_COLS];
let cols: &mut MemoryProgramPreprocessedCols<F> = row.as_mut_slice().borrow_mut();
cols.addr = F::from_canonical_u32(addr);
cols.value = Word::from(word);
cols.is_real = F::one();

row
})
Expand All @@ -91,30 +100,28 @@ impl<F: PrimeField> MachineAir<F> for MemoryProgramChip {
input: &ExecutionRecord,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Build a map of each address in program memory image to whether it was used.
// We have to do it from program because only the last shard has all the events, but every
// preprocessed row needs a corresponding mult row even if it's not used.
let mut addr_used_map = input
let program_memory_addrs = input
.program
.memory_image
.keys()
.map(|addr| (*addr, false))
.collect::<BTreeMap<_, _>>();
for event in &input.program_memory_events {
if event.used == 1 {
if let Some(used) = addr_used_map.get_mut(&event.addr) {
*used = true;
}
}
}
.copied()
.collect::<Vec<_>>();

let mult = if input.index == 1 {
F::one()
} else {
F::zero()
};

// Generate the trace rows for each event.
let rows = addr_used_map
.values()
.map(|used| {
let rows = program_memory_addrs
.into_iter()
.map(|_| {
let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS];
let cols: &mut MemoryProgramMultCols<F> = row.as_mut_slice().borrow_mut();
cols.used = F::from_bool(*used);
cols.multiplicity = mult;
IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1);

row
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -147,21 +154,47 @@ where
AB: SP1AirBuilder + PairBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let preprocessed = builder.preprocessed();
let main = builder.main();

let prep_local = preprocessed.row_slice(0);
let prep_local: &MemoryProgramPreprocessedCols<AB::Var> = (*prep_local).borrow();

let mult_local = main.row_slice(0);
let mult_local: &MemoryProgramMultCols<AB::Var> = (*mult_local).borrow();

builder.assert_bool(mult_local.used);
// Get shard from public values and evaluate whether it is the first shard.
let public_values = PublicValues::<Word<AB::Expr>, AB::Expr>::from_vec(
builder
.public_values()
.iter()
.map(|elm| (*elm).into())
.collect::<Vec<_>>(),
);
IsZeroOperation::<AB::F>::eval(
builder,
public_values.shard - AB::Expr::one(),
mult_local.is_first_shard,
prep_local.is_real.into(),
);
let is_first_shard = mult_local.is_first_shard.result;

// Multiplicity must be either 0 or 1.
builder.assert_bool(mult_local.multiplicity);
// If first shard and preprocessed is real, multiplicity must be one.
builder
.when(is_first_shard * prep_local.is_real)
.assert_one(mult_local.multiplicity);
// If not first shard or preprocessed is not real, multiplicity must be zero.
builder
.when((AB::Expr::one() - is_first_shard) + (AB::Expr::one() - prep_local.is_real))
.assert_zero(mult_local.multiplicity);

let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()];
values.extend(prep_local.value.map(Into::into));
builder.receive(AirInteraction::new(
values,
mult_local.used.into(),
mult_local.multiplicity.into(),
crate::lookup::InteractionKind::Memory,
));
}
Expand Down
31 changes: 0 additions & 31 deletions core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ use std::io::Write;
use std::rc::Rc;
use std::sync::Arc;

use nohash_hasher::BuildNoHashHasher;

use crate::memory::MemoryInitializeFinalizeEvent;
use crate::utils::env;
use crate::{alu::AluEvent, cpu::CpuEvent};
Expand Down Expand Up @@ -967,16 +965,6 @@ impl Runtime {
}

// SECTION: Set up all MemoryInitializeFinalizeEvents needed for memory argument.

// Program Memory is the global constants of the program. We need to mark which of these
// addresses are used by the program, as some invocations might not touch all addresses.
// program_memory_map maps an addr to its value and whether it was touched during the program.
let mut program_memory_map = HashMap::with_hasher(BuildNoHashHasher::<u32>::default());

for (key, value) in &self.program.memory_image {
program_memory_map.insert(key, (*value, true));
}

let memory_finalize_events = &mut self.record.memory_finalize_events;

// We handle the addr = 0 case separately, as we constrain it to be 0 in the first row
Expand All @@ -1002,30 +990,11 @@ impl Runtime {
}

let record = *self.state.memory.get(addr).unwrap();
if record.shard == 0 && record.timestamp == 0 {
// This means that we never accessed this memory location throughout our entire program.
// The only way this can happen is if this was in the program memory image.
// We mark this (addr, value) as not touched in the `program_memory_map` map.
program_memory_map.insert(addr, (record.value, false));
continue;
}

memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize_from_record(
*addr, &record,
));
}

let mut program_memory_events = program_memory_map
.into_iter()
.map(|(addr, (value, used))| {
MemoryInitializeFinalizeEvent::initialize(*addr, value, used)
})
.collect::<Vec<MemoryInitializeFinalizeEvent>>();
// Sort the program_memory_events by addr to create a canonical ordering for the
// preprocessed table, as this is part of the vkey.
program_memory_events.sort_by_key(|event| event.addr);

self.record.program_memory_events = program_memory_events;
}

fn get_syscall(&mut self, code: SyscallCode) -> Option<&Rc<dyn Syscall>> {
Expand Down
7 changes: 0 additions & 7 deletions core/src/runtime/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ pub struct ExecutionRecord {

pub memory_finalize_events: Vec<MemoryInitializeFinalizeEvent>,

pub program_memory_events: Vec<MemoryInitializeFinalizeEvent>,

pub bls12381_decompress_events: Vec<ECDecompressEvent>,

/// The public values.
Expand Down Expand Up @@ -272,8 +270,6 @@ impl MachineRecord for ExecutionRecord {
.append(&mut other.memory_initialize_events);
self.memory_finalize_events
.append(&mut other.memory_finalize_events);
self.program_memory_events
.append(&mut other.program_memory_events);
}

fn shard(mut self, config: &ShardingConfig) -> Vec<Self> {
Expand Down Expand Up @@ -479,9 +475,6 @@ impl MachineRecord for ExecutionRecord {
last_shard
.memory_finalize_events
.extend_from_slice(&self.memory_finalize_events);
last_shard
.program_memory_events
.extend_from_slice(&self.program_memory_events);

shards
}
Expand Down
4 changes: 0 additions & 4 deletions core/src/runtime/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,6 @@ impl SyscallCode {
pub fn num_cycles(&self) -> u32 {
(*self as u32).to_le_bytes()[2].into()
}

pub fn is_halt(&self) -> u32 {
(*self as u32).to_le_bytes()[3].into()
}
}

pub trait Syscall {
Expand Down
Loading

0 comments on commit f81fac0

Please sign in to comment.