Skip to content

Commit

Permalink
chore(recursion): poseidon2 loose ends (#672)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue authored May 9, 2024
1 parent 227bbdd commit e699a98
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ impl SP1Prover {
runtime.witness_stream = witness_stream.into();
runtime.run();
let mut checkpoint = runtime.memory.clone();
let checkpoint_uninit = runtime.uninitialized_memory.clone();

// Execute runtime.
let machine = RecursionAirWideDeg3::machine(InnerSC::default());
Expand All @@ -568,6 +569,7 @@ impl SP1Prover {
e.1.timestamp = BabyBear::zero();
});
runtime.memory = checkpoint;
runtime.uninitialized_memory = checkpoint_uninit;
runtime.run();
runtime.print_stats();
tracing::info!(
Expand Down
3 changes: 3 additions & 0 deletions recursion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ serde_with = "3.6.1"
backtrace = { version = "0.3.71", features = ["serde"] }
arrayref = "0.3.6"
static_assertions = "1.1.0"

[dev-dependencies]
rand = "0.8.5"
1 change: 1 addition & 0 deletions recursion/core/src/multi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ where
poseidon2_chip.eval_poseidon2(
&mut sub_builder,
local.poseidon2(),
next.poseidon2(),
local.poseidon2_receive_table,
local.poseidon2_memory_access.into(),
);
Expand Down
2 changes: 0 additions & 2 deletions recursion/core/src/poseidon2/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ pub struct Poseidon2Cols<T: Copy> {
pub left_input: T,
pub right_input: T,
pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output
pub is_computation: T,
pub is_memory_access: T,
pub round_specific_cols: RoundSpecificCols<T>,
}

Expand Down
66 changes: 52 additions & 14 deletions recursion/core/src/poseidon2/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use sp1_primitives::RC_16_30_U32;
use std::ops::Add;

use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder};
use crate::memory::MemoryCols;
use crate::poseidon2_wide::{apply_m_4, internal_linear_layer};
use crate::runtime::Opcode;

Expand Down Expand Up @@ -37,6 +38,7 @@ impl Poseidon2Chip {
&self,
builder: &mut AB,
local: &Poseidon2Cols<AB::Var>,
next: &Poseidon2Cols<AB::Var>,
receive_table: AB::Var,
memory_access: AB::Expr,
) {
Expand Down Expand Up @@ -65,6 +67,7 @@ impl Poseidon2Chip {
self.eval_mem(
builder,
local,
next,
is_memory_read,
is_memory_write,
memory_access,
Expand All @@ -73,6 +76,7 @@ impl Poseidon2Chip {
self.eval_computation(
builder,
local,
next,
is_initial.into(),
is_external_layer.clone(),
is_internal_layer.clone(),
Expand All @@ -94,12 +98,12 @@ impl Poseidon2Chip {
&self,
builder: &mut AB,
local: &Poseidon2Cols<AB::Var>,
next: &Poseidon2Cols<AB::Var>,
is_memory_read: AB::Var,
is_memory_write: AB::Var,
memory_access: AB::Expr,
) {
let memory_access_cols = local.round_specific_cols.memory_access();

builder
.when(is_memory_read)
.assert_eq(local.left_input, memory_access_cols.addr_first_half);
Expand Down Expand Up @@ -128,12 +132,24 @@ impl Poseidon2Chip {
memory_access.clone(),
);
}

// For the memory read round, need to connect the memory val to the input of the next
// computation round.
let next_computation_col = next.round_specific_cols.computation();
for i in 0..WIDTH {
builder.when_transition().when(is_memory_read).assert_eq(
*memory_access_cols.mem_access[i].value(),
next_computation_col.input[i],
);
}
}

#[allow(clippy::too_many_arguments)]
fn eval_computation<AB: BaseAirBuilder + ExtensionAirBuilder>(
&self,
builder: &mut AB,
local: &Poseidon2Cols<AB::Var>,
next: &Poseidon2Cols<AB::Var>,
is_initial: AB::Expr,
is_external_layer: AB::Expr,
is_internal_layer: AB::Expr,
Expand All @@ -158,11 +174,11 @@ impl Poseidon2Chip {
let mut result: AB::Expr = computation_cols.input[i].into();
for r in 0..rounds {
if i == 0 {
result += local.rounds[r + 1]
result += local.rounds[r + 2]
* constants[r][i]
* (is_external_layer.clone() + is_internal_layer.clone());
} else {
result += local.rounds[r + 1] * constants[r][i] * is_external_layer.clone();
result += local.rounds[r + 2] * constants[r][i] * is_external_layer.clone();
}
}
builder
Expand Down Expand Up @@ -251,9 +267,26 @@ impl Poseidon2Chip {
let mut state: [AB::Expr; WIDTH] = sbox_result.clone();
internal_linear_layer(&mut state);
builder
.when(is_internal_layer)
.when(is_internal_layer.clone())
.assert_all_eq(state.clone(), computation_cols.output);
}

// Assert that the round's output values are equal the the next round's input values. For the
// last computation round, assert athat the output values are equal to the output memory values.
let next_row_computation = next.round_specific_cols.computation();
let next_row_memory_access = next.round_specific_cols.memory_access();
for i in 0..WIDTH {
let next_round_value = builder.if_else(
local.rounds[22],
*next_row_memory_access.mem_access[i].value(),
next_row_computation.input[i],
);

builder
.when_transition()
.when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone())
.assert_eq(computation_cols.output[i], next_round_value);
}
}

fn eval_syscall<AB: BaseAirBuilder + ExtensionAirBuilder>(
Expand Down Expand Up @@ -295,9 +328,13 @@ where
let main = builder.main();
let local = main.row_slice(0);
let local: &Poseidon2Cols<AB::Var> = (*local).borrow();
let next = main.row_slice(1);
let next: &Poseidon2Cols<AB::Var> = (*next).borrow();

self.eval_poseidon2::<AB>(
builder,
local,
next,
Self::do_receive_table::<AB::Var>(local),
Self::do_memory_access::<AB::Var, AB::Expr>(local),
);
Expand All @@ -309,10 +346,10 @@ mod tests {
use itertools::Itertools;
use std::borrow::Borrow;
use std::time::Instant;
use zkhash::ark_ff::UniformRand;

use p3_baby_bear::BabyBear;
use p3_baby_bear::DiffusionMatrixBabyBear;
use p3_field::AbstractField;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_poseidon2::Poseidon2;
use p3_poseidon2::Poseidon2ExternalMatrixGeneral;
Expand All @@ -324,7 +361,7 @@ mod tests {
};

use crate::{
poseidon2::{Poseidon2Chip, Poseidon2Event, WIDTH},
poseidon2::{Poseidon2Chip, Poseidon2Event},
runtime::ExecutionRecord,
};
use p3_symmetric::Permutation;
Expand All @@ -338,12 +375,12 @@ mod tests {
let chip = Poseidon2Chip {
fixed_log2_rows: None,
};
let test_inputs = vec![
[BabyBear::from_canonical_u32(1); WIDTH],
[BabyBear::from_canonical_u32(2); WIDTH],
[BabyBear::from_canonical_u32(3); WIDTH],
[BabyBear::from_canonical_u32(4); WIDTH],
];

let rng = &mut rand::thread_rng();

let test_inputs: Vec<[BabyBear; 16]> = (0..16)
.map(|_| core::array::from_fn(|_| BabyBear::rand(rng)))
.collect_vec();

let gt: Poseidon2<
BabyBear,
Expand Down Expand Up @@ -384,9 +421,10 @@ mod tests {
let chip = Poseidon2Chip {
fixed_log2_rows: None,
};
let rng = &mut rand::thread_rng();

let test_inputs = (0..16)
.map(|i| [BabyBear::from_canonical_u32(i); WIDTH])
let test_inputs: Vec<[BabyBear; 16]> = (0..16)
.map(|_| core::array::from_fn(|_| BabyBear::rand(rng)))
.collect_vec();

let gt: Poseidon2<
Expand Down
2 changes: 1 addition & 1 deletion recursion/core/src/poseidon2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl<F: PrimeField32> Poseidon2Event<F> {
MemoryRecord::new_read(F::zero(), Block::from(input[i]), F::one(), F::zero())
});
let output_records: [MemoryRecord<F>; WIDTH] = core::array::from_fn(|i| {
MemoryRecord::new_read(F::zero(), Block::from(output[i]), F::one(), F::zero())
MemoryRecord::new_read(F::zero(), Block::from(output[i]), F::two(), F::zero())
});

Self {
Expand Down
4 changes: 2 additions & 2 deletions recursion/core/src/poseidon2/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2Chip {
// Apply the round constants.
for j in 0..WIDTH {
computation_cols.add_rc[j] = computation_cols.input[j]
+ F::from_wrapped_u32(RC_16_30_U32[r - 1][j]);
+ F::from_wrapped_u32(RC_16_30_U32[r - 2][j]);
}
} else {
// Apply the round constants only on the first element.
computation_cols
.add_rc
.copy_from_slice(&computation_cols.input);
computation_cols.add_rc[0] =
computation_cols.input[0] + F::from_wrapped_u32(RC_16_30_U32[r - 1][0]);
computation_cols.input[0] + F::from_wrapped_u32(RC_16_30_U32[r - 2][0]);
};

// Apply the sbox.
Expand Down
6 changes: 2 additions & 4 deletions recursion/core/src/poseidon2_wide/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,10 +493,8 @@ mod tests {
.push(Poseidon2Event::dummy_from_input(input, output));
}

let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());

assert_eq!(trace.height(), test_inputs.len());
// Generate trace will assert for the expected outputs.
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
}

/// A test generating a trace for a single permutation that checks that the output is correct
Expand Down

0 comments on commit e699a98

Please sign in to comment.