Skip to content

Commit

Permalink
Stwo support "addition and equal test" (#1942)
Browse files Browse the repository at this point in the history
## Initial Support for `stwo` Backend

This PR introduces the initial support for the `stwo` backend in the
Powdr pipeline.

### Summary
- **Constraint Proving**: Added functionality to prove and verify
constraints between columns.
- **Test Added**: A test case has been added to validate the above
functionality.

### Remaining Tasks
- Implement the `setup` function.
- Use `logup` for the next reference (now the next reference creates a
new column).
- Check the constant columns and public values in stwo and implement to
Powdr.

### How to Test
To test the `stwo` backend, use the following command:

```bash
cargo test --features stwo --package powdr-pipeline --test pil -- stwo_add_and_equal --exact --show-output

---------

Co-authored-by: Shuang Wu <[email protected]>
Co-authored-by: Shuang Wu <[email protected]>
Co-authored-by: Thibaut Schaeffer <[email protected]>
  • Loading branch information
4 people authored Nov 18, 2024
1 parent 58624e5 commit 8f3a572
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 31 deletions.
2 changes: 1 addition & 1 deletion backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ p3-commit = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf2
p3-matrix = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true }
p3-uni-stark = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0", optional = true }
# TODO: Change this to main branch when the `andrew/dev/update-toolchain` branch is merged,the main branch is using "nightly-2024-01-04", not compatiable with plonky3
stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "52d050c18b5dbc74af40214b3b441a6f60a20d41" }
stwo-prover = { git = "https://github.com/starkware-libs/stwo.git", optional = true, rev = "e6d10bc107c11cce54bb4aa152c3afa2e15e92c1" }

strum = { version = "0.24.1", features = ["derive"] }
log = "0.4.17"
Expand Down
10 changes: 1 addition & 9 deletions backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,7 @@ impl BackendType {
Box::new(composite::CompositeBackendFactory::new(plonky3::Factory))
}
#[cfg(feature = "stwo")]
BackendType::Stwo => Box::new(stwo::StwoProverFactory),
#[cfg(not(any(
feature = "halo2",
feature = "estark-polygon",
feature = "estark-starky",
feature = "plonky3",
feature = "stwo"
)))]
_ => panic!("Empty backend."),
BackendType::Stwo => Box::new(stwo::Factory),
}
}
}
Expand Down
193 changes: 193 additions & 0 deletions backend/src/stwo/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
use num_traits::Zero;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};

extern crate alloc;
use alloc::{collections::btree_map::BTreeMap, string::String, vec::Vec};
use powdr_ast::analyzed::{
AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, Analyzed, Identity,
};
use powdr_number::{FieldElement, LargeInt};
use std::sync::Arc;

use powdr_ast::analyzed::{
AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolyID, PolynomialType,
};
use stwo_prover::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval};
use stwo_prover::core::backend::ColumnOps;
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::{ExtensionOf, FieldExpOps, FieldOps};
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::ColumnVec;

pub type PowdrComponent<'a, F> = FrameworkComponent<PowdrEval<F>>;

pub(crate) fn gen_stwo_circuit_trace<T, B, F>(
witness: &[(String, Vec<T>)],
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>
where
T: FieldElement, //only Merenne31Field is supported, checked in runtime
B: FieldOps<M31> + ColumnOps<F>, // Ensure B implements FieldOps for M31
F: ExtensionOf<BaseField>,
{
assert!(
witness
.iter()
.all(|(_name, vec)| vec.len() == witness[0].1.len()),
"All Vec<T> in witness must have the same length. Mismatch found!"
);
let domain = CanonicCoset::new(witness[0].1.len().ilog2()).circle_domain();
witness
.iter()
.map(|(_name, values)| {
let values = values
.iter()
.map(|v| v.try_into_i32().unwrap().into())
.collect();
CircleEvaluation::new(domain, values)
})
.collect()
}

pub struct PowdrEval<T> {
analyzed: Arc<Analyzed<T>>,
witness_columns: BTreeMap<PolyID, usize>,
}

impl<T: FieldElement> PowdrEval<T> {
pub fn new(analyzed: Arc<Analyzed<T>>) -> Self {
let witness_columns: BTreeMap<PolyID, usize> = analyzed
.definitions_in_source_order(PolynomialType::Committed)
.flat_map(|(symbol, _)| symbol.array_elements())
.enumerate()
.map(|(index, (_, id))| (id, index))
.collect();

Self {
analyzed,
witness_columns,
}
}
}

impl<T: FieldElement> FrameworkEval for PowdrEval<T> {
fn log_size(&self) -> u32 {
self.analyzed.degree().ilog2()
}
fn max_constraint_log_degree_bound(&self) -> u32 {
self.analyzed.degree().ilog2() + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
assert!(
self.analyzed.constant_count() == 0 && self.analyzed.publics_count() == 0,
"Error: Expected no fixed columns nor public inputs, as they are not supported yet.",
);

let witness_eval: BTreeMap<PolyID, [<E as EvalAtRow>::F; 2]> = self
.witness_columns
.keys()
.map(|poly_id| (*poly_id, eval.next_interaction_mask(0, [0, 1])))
.collect();

for id in self
.analyzed
.identities_with_inlined_intermediate_polynomials()
{
match id {
Identity::Polynomial(identity) => {
let expr = to_stwo_expression(&identity.expression, &witness_eval);
eval.add_constraint(expr);
}
Identity::Connect(..) => {
unimplemented!("Connect is not implemented in stwo yet")
}
Identity::Lookup(..) => {
unimplemented!("Lookup is not implemented in stwo yet")
}
Identity::Permutation(..) => {
unimplemented!("Permutation is not implemented in stwo yet")
}
Identity::PhantomPermutation(..) => {}
Identity::PhantomLookup(..) => {}
}
}
eval
}
}

fn to_stwo_expression<T: FieldElement, F>(
expr: &AlgebraicExpression<T>,
witness_eval: &BTreeMap<PolyID, [F; 2]>,
) -> F
where
F: FieldExpOps
+ Clone
+ Debug
+ Zero
+ Neg<Output = F>
+ AddAssign
+ AddAssign<BaseField>
+ Add<F, Output = F>
+ Sub<F, Output = F>
+ Mul<BaseField, Output = F>
+ Neg<Output = F>
+ From<BaseField>,
{
use AlgebraicBinaryOperator::*;
match expr {
AlgebraicExpression::Reference(r) => {
let poly_id = r.poly_id;

match poly_id.ptype {
PolynomialType::Committed => match r.next {
false => witness_eval[&poly_id][0].clone(),
true => witness_eval[&poly_id][1].clone(),
},
PolynomialType::Constant => {
unimplemented!("Constant polynomials are not supported in stwo yet")
}
PolynomialType::Intermediate => {
unimplemented!("Intermediate polynomials are not supported in stwo yet")
}
}
}
AlgebraicExpression::PublicReference(..) => {
unimplemented!("Public references are not supported in stwo yet")
}
AlgebraicExpression::Number(n) => F::from(M31::from(n.try_into_i32().unwrap())),
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: Pow,
right,
}) => match **right {
AlgebraicExpression::Number(n) => {
let left = to_stwo_expression(left, witness_eval);
(0u32..n.to_integer().try_into_u32().unwrap())
.fold(F::one(), |acc, _| acc * left.clone())
}
_ => unimplemented!("pow with non-constant exponent"),
},
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => {
let left = to_stwo_expression(left, witness_eval);
let right = to_stwo_expression(right, witness_eval);

match op {
Add => left + right,
Sub => left - right,
Mul => left * right,
Pow => unreachable!("This case was handled above"),
}
}
AlgebraicExpression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => {
let expr = to_stwo_expression(expr, witness_eval);

match op {
AlgebraicUnaryOperator::Minus => -expr,
}
}
AlgebraicExpression::Challenge(_challenge) => {
unimplemented!("challenges are not supported in stwo yet")
}
}
}
40 changes: 30 additions & 10 deletions backend/src/stwo/mod.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;

use crate::{Backend, BackendFactory, BackendOptions, Error, Proof};
use crate::{
field_filter::generalize_factory, Backend, BackendFactory, BackendOptions, Error, Proof,
};
use powdr_ast::analyzed::Analyzed;
use powdr_executor::constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn};
use powdr_executor::witgen::WitgenCallback;
use powdr_number::FieldElement;
use powdr_number::{FieldElement, Mersenne31Field};
use prover::StwoProver;
use stwo_prover::core::backend::{simd::SimdBackend, BackendForChannel};
use stwo_prover::core::channel::{Blake2sChannel, Channel, MerkleChannel};
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel;

mod circuit_builder;
mod prover;

#[allow(dead_code)]
pub(crate) struct StwoProverFactory;

impl<F: FieldElement> BackendFactory<F> for StwoProverFactory {
struct RestrictedFactory;

impl<F: FieldElement> BackendFactory<F> for RestrictedFactory {
#[allow(unreachable_code)]
#[allow(unused_variables)]
fn create(
Expand All @@ -37,16 +45,28 @@ impl<F: FieldElement> BackendFactory<F> for StwoProverFactory {
let fixed = Arc::new(
get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?,
);
let stwo = Box::new(StwoProver::new(pil, fixed, setup)?);
let stwo: Box<StwoProver<F, SimdBackend, Blake2sMerkleChannel, Blake2sChannel>> =
Box::new(StwoProver::new(pil, fixed)?);
Ok(stwo)
}
}

impl<T: FieldElement> Backend<T> for StwoProver<T> {
generalize_factory!(Factory <- RestrictedFactory, [Mersenne31Field]);

impl<T: FieldElement, MC: MerkleChannel + Send, C: Channel + Send> Backend<T>
for StwoProver<T, SimdBackend, MC, C>
where
SimdBackend: BackendForChannel<MC>,
MC: MerkleChannel,
C: Channel,
MC::H: DeserializeOwned + Serialize,
{
#[allow(unused_variables)]
fn verify(&self, proof: &[u8], instances: &[Vec<T>]) -> Result<(), Error> {
assert!(instances.len() == 1);
unimplemented!()
assert_eq!(instances.len(), 1);
let instances = &instances[0];

Ok(self.verify(proof, instances)?)
}
#[allow(unreachable_code)]
#[allow(unused_variables)]
Expand All @@ -59,7 +79,7 @@ impl<T: FieldElement> Backend<T> for StwoProver<T> {
if prev_proof.is_some() {
return Err(Error::NoAggregationAvailable);
}
unimplemented!()
Ok(StwoProver::prove(self, witness)?)
}
#[allow(unused_variables)]
fn export_verification_key(&self, output: &mut dyn io::Write) -> Result<(), Error> {
Expand Down
Loading

0 comments on commit 8f3a572

Please sign in to comment.