Skip to content

Commit

Permalink
chore: cleanup for allen (#518)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Apr 15, 2024
1 parent 7390181 commit c440969
Show file tree
Hide file tree
Showing 91 changed files with 1,922 additions and 589 deletions.
113 changes: 67 additions & 46 deletions core/src/air/builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::iter::once;

use itertools::Itertools;
use p3_air::{AirBuilder, FilteredAirBuilder};
use p3_air::{AirBuilderWithPublicValues, PermutationAirBuilder};
use p3_field::{AbstractField, Field};
use p3_uni_stark::StarkGenericConfig;
use p3_uni_stark::{ProverConstraintFolder, SymbolicAirBuilder, VerifierConstraintFolder};

use super::interaction::AirInteraction;
Expand All @@ -11,10 +15,6 @@ use crate::cpu::columns::OpcodeSelectorCols;
use crate::lookup::InteractionKind;
use crate::memory::MemoryAccessCols;
use crate::{bytes::ByteOpcode, memory::MemoryCols};
use p3_field::{AbstractField, Field};

use p3_uni_stark::StarkGenericConfig;
use std::iter::once;

/// A Builder with the ability to encode the existance of interactions with other AIRs by sending
/// and receiving messages.
Expand Down Expand Up @@ -240,6 +240,25 @@ pub trait WordAirBuilder: ByteAirBuilder {
result
}

/// Same as `if_else` above, but arguments are `Word` instead of individual expressions.
fn select_word<ECond, EA, EB>(
&mut self,
condition: ECond,
a: Word<EA>,
b: Word<EB>,
) -> Word<Self::Expr>
where
ECond: Into<Self::Expr> + Clone,
EA: Into<Self::Expr> + Clone,
EB: Into<Self::Expr> + Clone,
{
let mut res = vec![];
for i in 0..WORD_SIZE {
res.push(self.if_else(condition.clone(), a[i].clone(), b[i].clone()));
}
Word(res.try_into().unwrap())
}

/// Check that each limb of the given slice is a u8.
fn slice_range_check_u8<
EWord: Into<Self::Expr> + Clone,
Expand Down Expand Up @@ -426,11 +445,11 @@ pub trait AluAirBuilder: BaseAirBuilder {

/// A trait which contains methods related to memory interactions in an AIR.
pub trait MemoryAirBuilder: BaseAirBuilder {
/// Constraints a memory read or write.
/// Constrain a memory read or write.
///
/// This method verifies that a memory access timestamp (shard, clk) is greater than the
/// previous access's timestamp. It will also add to the memory argument.
fn constraint_memory_access<EClk, EShard, Ea, Eb, EVerify, M>(
fn eval_memory_access<EClk, EShard, Ea, Eb, EVerify, M>(
&mut self,
shard: EShard,
clk: EClk,
Expand All @@ -453,7 +472,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
self.assert_bool(do_check.clone());

// Verify that the current memory access time is greater than the previous's.
self.verify_mem_access_ts(mem_access, do_check.clone(), shard.clone(), clk.clone());
self.eval_memory_access_timestamp(mem_access, do_check.clone(), shard.clone(), clk.clone());

// Add to the memory argument.
let addr = addr.into();
Expand Down Expand Up @@ -485,13 +504,39 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
));
}

/// Constraints a memory read or write to a slice of `MemoryAccessCols`.
fn eval_memory_access_slice<EShard, Ea, Eb, EVerify, M>(
&mut self,
shard: EShard,
clk: Self::Expr,
initial_addr: Ea,
memory_access_slice: &[M],
verify_memory_access: EVerify,
) where
EShard: Into<Self::Expr> + Copy,
Ea: Into<Self::Expr> + Copy,
Eb: Into<Self::Expr> + Copy,
EVerify: Into<Self::Expr> + Copy,
M: MemoryCols<Eb>,
{
for (i, access_slice) in memory_access_slice.iter().enumerate() {
self.eval_memory_access(
shard,
clk.clone(),
initial_addr.into() + Self::Expr::from_canonical_usize(i * 4),
access_slice,
verify_memory_access,
);
}
}

/// Verifies the memory access timestamp.
///
/// This method verifies that the current memory access happend after the previous one's. Specifically
/// it will ensure that if the current and previous access are in the same shard, then the
/// current's clk val is greater than the previous's. If they are not in the same shard, then
/// it will ensure that the current's shard val is greater than the previous's.
fn verify_mem_access_ts<Eb, EVerify, EShard, EClk>(
/// This method verifies that the current memory access happend after the previous one's.
/// Specifically it will ensure that if the current and previous access are in the same shard,
/// then the current's clk val is greater than the previous's. If they are not in the same
/// shard, then it will ensure that the current's shard val is greater than the previous's.
fn eval_memory_access_timestamp<Eb, EVerify, EShard, EClk>(
&mut self,
mem_access: &MemoryAccessCols<Eb>,
do_check: EVerify,
Expand Down Expand Up @@ -526,15 +571,17 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
// Assert `current_comp_val > prev_comp_val`. We check this by asserting that
// `0 <= current_comp_val-prev_comp_val-1 < 2^24`.
//
// The equivalence of these statements comes from the fact that if `current_comp_val <= prev_comp_val`,
// then `current_comp_val-prev_comp_val-1 < 0` and will underflow in the prime field,
// resulting in a value that is `>= 2^24` as long as both `current_comp_val, prev_comp_val` are
// range-checked to be `<2^24` and as long as we're working in a field larger than `2 * 2^24`
// (which is true of the BabyBear and Mersenne31 prime).
// The equivalence of these statements comes from the fact that if
// `current_comp_val <= prev_comp_val`, then `current_comp_val-prev_comp_val-1 < 0` and will
// underflow in the prime field, resulting in a value that is `>= 2^24` as long as both
// `current_comp_val, prev_comp_val` are range-checked to be `<2^24` and as long as we're
// working in a field larger than `2 * 2^24` (which is true of the BabyBear and Mersenne31
// prime).
let diff_minus_one = current_comp_val - prev_comp_value - Self::Expr::one();

// Verify that mem_access.ts_diff = mem_access.ts_diff_16bit_limb + mem_access.ts_diff_8bit_limb * 2^16.
self.verify_range_24bits(
// Verify that mem_access.ts_diff = mem_access.ts_diff_16bit_limb
// + mem_access.ts_diff_8bit_limb * 2^16.
self.eval_range_check_24bits(
diff_minus_one,
mem_access.diff_16bit_limb.clone(),
mem_access.diff_8bit_limb.clone(),
Expand All @@ -549,7 +596,7 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
/// check on it's limbs. It will also verify that the limbs are correct. This method is needed
/// since the memory access timestamp check (see [Self::verify_mem_access_ts]) needs to assume
/// the clk is within 24 bits.
fn verify_range_24bits<EValue, ELimb, EShard, EVerify>(
fn eval_range_check_24bits<EValue, ELimb, EShard, EVerify>(
&mut self,
value: EValue,
limb_16: ELimb,
Expand Down Expand Up @@ -588,32 +635,6 @@ pub trait MemoryAirBuilder: BaseAirBuilder {
do_check,
)
}

/// Constraints a memory read or write to a slice of `MemoryAccessCols`.
fn constraint_memory_access_slice<EShard, Ea, Eb, EVerify, M>(
&mut self,
shard: EShard,
clk: Self::Expr,
initial_addr: Ea,
memory_access_slice: &[M],
verify_memory_access: EVerify,
) where
EShard: Into<Self::Expr> + Copy,
Ea: Into<Self::Expr> + Copy,
Eb: Into<Self::Expr> + Copy,
EVerify: Into<Self::Expr> + Copy,
M: MemoryCols<Eb>,
{
for (i, access_slice) in memory_access_slice.iter().enumerate() {
self.constraint_memory_access(
shard,
clk.clone(),
initial_addr.into() + Self::Expr::from_canonical_usize(i * 4),
access_slice,
verify_memory_access,
);
}
}
}

/// A trait which contains methods related to program interactions in an AIR.
Expand Down
3 changes: 2 additions & 1 deletion core/src/air/extension.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::ops::{Add, Mul, Neg, Sub};

use p3_field::{
extension::{BinomialExtensionField, BinomiallyExtendable},
AbstractExtensionField, AbstractField,
};
use sp1_derive::AlignedBorrow;
use std::ops::{Add, Mul, Neg, Sub};

const DEGREE: usize = 4;

Expand Down
6 changes: 3 additions & 3 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::stark::MachineRecord;

pub use sp1_derive::MachineAir;

/// An AIR that is part of a Risc-V AIR arithmetization.
/// An AIR that is part of a multi table AIR arithmetization.
pub trait MachineAir<F: Field>: BaseAir<F> {
/// The execution record containing events for producing the air trace.
type Record: MachineRecord;
Expand All @@ -31,13 +31,13 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
/// Whether this execution record contains events for this air.
fn included(&self, shard: &Self::Record) -> bool;

/// The width of the preprocessed trace.
fn preprocessed_width(&self) -> usize {
0
}

/// Generate the preprocessed trace given a specific program.
#[allow(unused_variables)]
fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
None
}
}
1 change: 1 addition & 0 deletions core/src/air/polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::fmt::Debug;
use core::ops::{Add, AddAssign, Mul, Neg, Sub};

use itertools::Itertools;
use p3_field::{AbstractExtensionField, AbstractField, Field};

Expand Down
10 changes: 6 additions & 4 deletions core/src/air/public_values.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::stark::PROOF_MAX_NUM_PVS;

use super::Word;
use core::fmt::Debug;
use core::mem::size_of;
use std::iter::once;

use itertools::Itertools;
use p3_field::{AbstractField, PrimeField32};
use serde::{Deserialize, Serialize};
use std::iter::once;

use super::Word;
use crate::stark::PROOF_MAX_NUM_PVS;

/// The number of words needed to represent a public value digest.
/// The number of non padded elements in the SP1 proofs public values vec.
pub const SP1_PROOF_NUM_PV_ELTS: usize = size_of::<PublicValues<Word<u8>, u8>>();

Expand Down
3 changes: 1 addition & 2 deletions core/src/air/word.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::fmt::Debug;
use std::array::IntoIter;
use std::ops::{Index, IndexMut};

Expand All @@ -8,8 +9,6 @@ use p3_field::Field;
use serde::{Deserialize, Serialize};
use sp1_derive::AlignedBorrow;

use core::fmt::Debug;

use super::SP1AirBuilder;

/// The size of a word in bytes.
Expand Down
24 changes: 13 additions & 11 deletions core/src/alu/add_sub/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;

use p3_air::{Air, BaseAir};
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
Expand Down Expand Up @@ -35,11 +36,6 @@ pub struct AddSubCols<T> {
/// The shard number, used for byte lookup table.
pub shard: T,

/// Boolean to indicate whether the row is for an add operation.
pub is_add: T,
/// Boolean to indicate whether the row is for a sub operation.
pub is_sub: T,

/// Instance of `AddOperation` to handle addition logic in `AddSubChip`'s ALU operations.
/// It's result will be `a` for the add operation and `b` for the sub operation.
pub add_operation: AddOperation<T>,
Expand All @@ -49,6 +45,12 @@ pub struct AddSubCols<T> {

/// The second input operand. This will be `c` for both operations.
pub operand_2: Word<T>,

/// Boolean to indicate whether the row is for an add operation.
pub is_add: T,

/// Boolean to indicate whether the row is for a sub operation.
pub is_sub: T,
}

impl<F: PrimeField> MachineAir<F> for AddSubChip {
Expand Down Expand Up @@ -143,19 +145,14 @@ where
let local = main.row_slice(0);
let local: &AddSubCols<AB::Var> = (*local).borrow();

builder.assert_bool(local.is_add);
builder.assert_bool(local.is_sub);
let is_real = local.is_add + local.is_sub;
builder.assert_bool(is_real.clone());

// Evaluate the addition operation.
AddOperation::<AB::F>::eval(
builder,
local.operand_1,
local.operand_2,
local.add_operation,
local.shard,
is_real,
local.is_add + local.is_sub,
);

// Receive the arguments. There are seperate receives for ADD and SUB.
Expand All @@ -179,6 +176,11 @@ where
local.is_sub,
);

let is_real = local.is_add + local.is_sub;
builder.assert_bool(local.is_add);
builder.assert_bool(local.is_sub);
builder.assert_bool(is_real);

// Degree 3 constraint to avoid "OodEvaluationMismatch".
builder.assert_zero(
local.operand_1[0] * local.operand_1[0] * local.operand_1[0]
Expand Down
16 changes: 13 additions & 3 deletions core/src/alu/bitwise/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;

use p3_air::{Air, BaseAir};
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
Expand Down Expand Up @@ -140,18 +141,27 @@ where
builder.send_byte(opcode.clone(), a, b, c, local.shard, mult.clone());
}

// Get the cpu opcode, which corresponds to the opcode being sent in the CPU table.
let cpu_opcode = local.is_xor * Opcode::XOR.as_field::<AB::F>()
+ local.is_or * Opcode::OR.as_field::<AB::F>()
+ local.is_and * Opcode::AND.as_field::<AB::F>();

// Receive the arguments.
builder.receive_alu(
local.is_xor * Opcode::XOR.as_field::<AB::F>()
+ local.is_or * Opcode::OR.as_field::<AB::F>()
+ local.is_and * Opcode::AND.as_field::<AB::F>(),
cpu_opcode,
local.a,
local.b,
local.c,
local.shard,
local.is_xor + local.is_or + local.is_and,
);

let is_real = local.is_xor + local.is_or + local.is_and;
builder.assert_bool(local.is_xor);
builder.assert_bool(local.is_or);
builder.assert_bool(local.is_and);
builder.assert_bool(is_real);

// Degree 3 constraint to avoid "OodEvaluationMismatch".
builder.assert_zero(
local.a[0] * local.b[0] * local.c[0] - local.a[0] * local.b[0] * local.c[0],
Expand Down
14 changes: 8 additions & 6 deletions core/src/alu/divrem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ mod utils;

use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;

use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
use p3_field::PrimeField;
Expand Down Expand Up @@ -716,17 +717,18 @@ where
// Check that the flags are boolean.
{
let bool_flags = [
local.is_real,
local.is_remu,
local.is_div,
local.is_divu,
local.is_rem,
local.is_div,
local.b_neg,
local.rem_neg,
local.is_remu,
local.is_overflow,
local.b_msb,
local.rem_msb,
local.c_neg,
local.c_msb,
local.b_neg,
local.rem_neg,
local.c_neg,
local.is_real,
];

for flag in bool_flags.iter() {
Expand Down
Loading

0 comments on commit c440969

Please sign in to comment.