Skip to content

Commit

Permalink
Merge pull request #131 from erhant/erhant/public-output-array
Browse files Browse the repository at this point in the history
feat: Allow array of fields on public output
  • Loading branch information
mimoo authored Jun 21, 2024
2 parents d422cc3 + f6a563d commit 96cd1f1
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ark-ff = "0.3.0"
ark-bls12-381 = "0.3.0" # bls12-381 curve for r1cs backend
ark-bn254 = "0.3.0" # bn128 curve for r1cs backend
ark-serialize = "0.3.0" # serialization of arkworks types
educe = { version = "0.6", default-features = false, features = ["Hash", "PartialEq"] }
educe = { version = "0.6", default-features = false, features = ["Hash", "PartialEq", "PartialOrd"] }
ena = "0.14.0" # union-find implementation for the wiring
num-bigint = "0.4.3" # big int library
camino = "1.1.1" # to replace Path and PathBuf
Expand Down
17 changes: 17 additions & 0 deletions examples/fixture/asm/kimchi/public_output_array.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@ noname.0.7.0

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1,1,-1>
DoubleGeneric<1,0,0,0,-2>
DoubleGeneric<1,0,-1,0,6>
DoubleGeneric<0,0,-1,1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
(0,0) -> (7,0)
(1,0) -> (8,0)
(2,0) -> (3,1) -> (6,1)
(3,2) -> (4,0) -> (5,0) -> (6,0)
(5,2) -> (7,1)
(6,2) -> (8,1)
28 changes: 28 additions & 0 deletions examples/fixture/asm/kimchi/types_array_output.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@ noname.0.7.0

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<2,0,-1>
DoubleGeneric<2,0,-1>
DoubleGeneric<2,0,-1>
DoubleGeneric<1,-1>
DoubleGeneric<2,0,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
(0,0) -> (12,0)
(1,0) -> (13,0)
(2,0) -> (14,0)
(3,0) -> (15,0)
(4,0) -> (6,0) -> (8,0) -> (14,1)
(5,0) -> (7,0) -> (10,0) -> (13,1)
(6,2) -> (9,1) -> (12,1)
(7,2) -> (11,1) -> (15,1)
(8,2) -> (9,0)
(10,2) -> (11,0)
6 changes: 6 additions & 0 deletions examples/fixture/asm/r1cs/public_output_array.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@ noname.0.7.0

2 == (v_3 + v_4) * (1)
v_5 == (v_3 + v_4) * (v_3)
v_3 + v_4 + 6 == (v_1) * (1)
v_5 == (v_2) * (1)
8 changes: 8 additions & 0 deletions examples/fixture/asm/r1cs/types_array_output.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@ noname.0.7.0

2 * v_5 == (2 * v_5) * (1)
2 * v_6 == (2 * v_6) * (1)
2 * v_5 == (v_1) * (1)
v_6 == (v_2) * (1)
v_5 == (v_3) * (1)
2 * v_6 == (v_4) * (1)
6 changes: 6 additions & 0 deletions examples/public_output_array.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn main(pub public_input: Field, private_input: Field) -> [Field; 2] {
let xx = private_input + public_input;
assert_eq(xx, 2);
let yy = xx + 6;
return [yy, xx * public_input];
}
21 changes: 21 additions & 0 deletions examples/types_array_output.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
struct Thing {
xx: Field,
yy: Field,
}

fn main(pub xx: Field, pub yy: Field) -> [Thing; 2] {
let thing1 = Thing {
xx: xx * 2,
yy: yy,
};
let thing2 = Thing {
xx: xx,
yy: yy * 2,
};
let things = [thing1, thing2];

assert_eq(things[1].xx * 2, things[0].xx);
assert_eq(things[0].yy * 2, things[1].yy);

return things;
}
9 changes: 6 additions & 3 deletions src/backends/kimchi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ pub mod asm;
pub mod builtin;
pub mod prover;

use educe::Educe;
use std::{
collections::{HashMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
fmt::Write,
ops::Neg as _,
};
Expand Down Expand Up @@ -243,9 +244,11 @@ impl KimchiVesta {
}
}

#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Default, Clone, Copy, Debug, Eq, Hash, Serialize, Deserialize, PartialEq, Ord, Educe)]
#[educe(PartialOrd)]
pub struct KimchiCellVar {
index: usize,
#[educe(PartialOrd(ignore))]
pub span: Span,
}

Expand Down Expand Up @@ -391,7 +394,7 @@ impl Backend for KimchiVesta {

let mut witness = vec![];
// compute each rows' vars, except for the deferred ones (public output)
let mut public_outputs_vars: HashMap<KimchiCellVar, Vec<(usize, usize)>> = HashMap::new();
let mut public_outputs_vars: BTreeMap<KimchiCellVar, Vec<(usize, usize)>> = BTreeMap::new();

// calculate witness except for public outputs
for (row, row_of_vars) in self.witness_table.iter().enumerate() {
Expand Down
9 changes: 3 additions & 6 deletions src/circuit_writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,9 @@ impl<B: Backend> CircuitWriter<B> {

// create public output
if let Some(typ) = &function.sig.return_type {
if typ.kind != TyKind::Field {
unimplemented!();
}

// create it
circuit_writer.add_public_outputs(1, typ.span);
// whatever is the size of return type, we need to add that many public outputs
let size_of = circuit_writer.size_of(&typ.kind);
circuit_writer.add_public_outputs(size_of, typ.span);
}

// public inputs should be handled first
Expand Down
36 changes: 36 additions & 0 deletions src/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,39 @@ fn test_literals(#[case] backend: BackendKind) -> miette::Result<()> {

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_public_output_array(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"public_input": "1"}"#;
let private_inputs = r#"{"private_input": "1"}"#;

test_file(
"public_output_array",
public_inputs,
private_inputs,
vec!["8", "2"],
backend,
)?;

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_types_array_output(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"xx": "1", "yy": "4"}"#;
let private_inputs = r#"{}"#;

test_file(
"types_array_output",
public_inputs,
private_inputs,
vec!["2", "4", "1", "8"], // 2x, y, x, 2y
backend,
)?;

Ok(())
}
21 changes: 14 additions & 7 deletions src/type_checker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,21 @@ impl<B: Backend> TypeChecker<B> {
// the output value returned by the main function is also a main_args with a special name (public_output)
if let Some(typ) = &function.sig.return_type {
if is_main {
if !matches!(typ.kind, TyKind::Field) {
unimplemented!();
match typ.kind {
TyKind::Field => {
typed_fn_env.store_type(
"public_output".to_string(),
TypeInfo::new_mut(typ.kind.clone(), typ.span),
)?;
}
TyKind::Array(_, _) => {
typed_fn_env.store_type(
"public_output".to_string(),
TypeInfo::new_mut(typ.kind.clone(), typ.span),
)?;
}
_ => unimplemented!(),
}

typed_fn_env.store_type(
"public_output".to_string(),
TypeInfo::new_mut(typ.kind.clone(), typ.span),
)?;
}
}

Expand Down

0 comments on commit 96cd1f1

Please sign in to comment.