diff --git a/Cargo.toml b/Cargo.toml index e0807f744..1d736a14d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ ark-ff = "0.3.0" ark-bls12-381 = "0.3.0" # bls12-381 curve for r1cs backend ark-bn254 = "0.3.0" # bn128 curve for r1cs backend ark-serialize = "0.3.0" # serialization of arkworks types -educe = { version = "0.6", default-features = false, features = ["Hash", "PartialEq"] } +educe = { version = "0.6", default-features = false, features = ["Hash", "PartialEq", "PartialOrd"] } ena = "0.14.0" # union-find implementation for the wiring num-bigint = "0.4.3" # big int library camino = "1.1.1" # to replace Path and PathBuf diff --git a/examples/fixture/asm/kimchi/public_output_array.asm b/examples/fixture/asm/kimchi/public_output_array.asm new file mode 100644 index 000000000..bc6986b34 --- /dev/null +++ b/examples/fixture/asm/kimchi/public_output_array.asm @@ -0,0 +1,17 @@ +@ noname.0.7.0 + +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1,1,-1> +DoubleGeneric<1,0,0,0,-2> +DoubleGeneric<1,0,-1,0,6> +DoubleGeneric<0,0,-1,1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +(0,0) -> (7,0) +(1,0) -> (8,0) +(2,0) -> (3,1) -> (6,1) +(3,2) -> (4,0) -> (5,0) -> (6,0) +(5,2) -> (7,1) +(6,2) -> (8,1) diff --git a/examples/fixture/asm/kimchi/types_array_output.asm b/examples/fixture/asm/kimchi/types_array_output.asm new file mode 100644 index 000000000..71c87d11e --- /dev/null +++ b/examples/fixture/asm/kimchi/types_array_output.asm @@ -0,0 +1,28 @@ +@ noname.0.7.0 + +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<2,0,-1> +DoubleGeneric<2,0,-1> +DoubleGeneric<2,0,-1> +DoubleGeneric<1,-1> +DoubleGeneric<2,0,-1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +(0,0) -> (12,0) +(1,0) -> (13,0) +(2,0) -> (14,0) +(3,0) -> (15,0) +(4,0) -> (6,0) -> (8,0) -> (14,1) +(5,0) -> (7,0) -> (10,0) -> (13,1) +(6,2) -> (9,1) -> (12,1) +(7,2) -> (11,1) -> (15,1) +(8,2) -> (9,0) +(10,2) -> (11,0) diff --git a/examples/fixture/asm/r1cs/public_output_array.asm b/examples/fixture/asm/r1cs/public_output_array.asm new file mode 100644 index 000000000..caee1b5af --- /dev/null +++ b/examples/fixture/asm/r1cs/public_output_array.asm @@ -0,0 +1,6 @@ +@ noname.0.7.0 + +2 == (v_3 + v_4) * (1) +v_5 == (v_3 + v_4) * (v_3) +v_3 + v_4 + 6 == (v_1) * (1) +v_5 == (v_2) * (1) diff --git a/examples/fixture/asm/r1cs/types_array_output.asm b/examples/fixture/asm/r1cs/types_array_output.asm new file mode 100644 index 000000000..ad73bd17d --- /dev/null +++ b/examples/fixture/asm/r1cs/types_array_output.asm @@ -0,0 +1,8 @@ +@ noname.0.7.0 + +2 * v_5 == (2 * v_5) * (1) +2 * v_6 == (2 * v_6) * (1) +2 * v_5 == (v_1) * (1) +v_6 == (v_2) * (1) +v_5 == (v_3) * (1) +2 * v_6 == (v_4) * (1) diff --git a/examples/public_output_array.no b/examples/public_output_array.no new file mode 100644 index 000000000..60cd2f870 --- /dev/null +++ b/examples/public_output_array.no @@ -0,0 +1,6 @@ +fn main(pub public_input: Field, private_input: Field) -> [Field; 2] { + let xx = private_input + public_input; + assert_eq(xx, 2); + let yy = xx + 6; + return [yy, xx * public_input]; +} diff --git a/examples/types_array_output.no b/examples/types_array_output.no new file mode 100644 index 000000000..43c75493e --- /dev/null +++ b/examples/types_array_output.no @@ -0,0 +1,21 @@ +struct Thing { + xx: Field, + yy: Field, +} + +fn main(pub xx: Field, pub yy: Field) -> [Thing; 2] { + let thing1 = Thing { + xx: xx * 2, + yy: yy, + }; + let thing2 = Thing { + xx: xx, + yy: yy * 2, + }; + let things = [thing1, thing2]; + + assert_eq(things[1].xx * 2, things[0].xx); + assert_eq(things[0].yy * 2, things[1].yy); + + return things; +} diff --git a/src/backends/kimchi/mod.rs b/src/backends/kimchi/mod.rs index fb77f5633..095ddd30b 100644 --- a/src/backends/kimchi/mod.rs +++ b/src/backends/kimchi/mod.rs @@ -2,8 +2,9 @@ pub mod asm; pub mod builtin; pub mod prover; +use educe::Educe; use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, fmt::Write, ops::Neg as _, }; @@ -243,9 +244,11 @@ impl KimchiVesta { } } -#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Default, Clone, Copy, Debug, Eq, Hash, Serialize, Deserialize, PartialEq, Ord, Educe)] +#[educe(PartialOrd)] pub struct KimchiCellVar { index: usize, + #[educe(PartialOrd(ignore))] pub span: Span, } @@ -391,7 +394,7 @@ impl Backend for KimchiVesta { let mut witness = vec![]; // compute each rows' vars, except for the deferred ones (public output) - let mut public_outputs_vars: HashMap> = HashMap::new(); + let mut public_outputs_vars: BTreeMap> = BTreeMap::new(); // calculate witness except for public outputs for (row, row_of_vars) in self.witness_table.iter().enumerate() { diff --git a/src/circuit_writer/mod.rs b/src/circuit_writer/mod.rs index ec935ae9b..7ab529941 100644 --- a/src/circuit_writer/mod.rs +++ b/src/circuit_writer/mod.rs @@ -163,12 +163,9 @@ impl CircuitWriter { // create public output if let Some(typ) = &function.sig.return_type { - if typ.kind != TyKind::Field { - unimplemented!(); - } - - // create it - circuit_writer.add_public_outputs(1, typ.span); + // whatever is the size of return type, we need to add that many public outputs + let size_of = circuit_writer.size_of(&typ.kind); + circuit_writer.add_public_outputs(size_of, typ.span); } // public inputs should be handled first diff --git a/src/tests/examples.rs b/src/tests/examples.rs index e13947a8c..399857278 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -424,3 +424,39 @@ fn test_literals(#[case] backend: BackendKind) -> miette::Result<()> { Ok(()) } + +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_public_output_array(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"public_input": "1"}"#; + let private_inputs = r#"{"private_input": "1"}"#; + + test_file( + "public_output_array", + public_inputs, + private_inputs, + vec!["8", "2"], + backend, + )?; + + Ok(()) +} + +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_types_array_output(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"xx": "1", "yy": "4"}"#; + let private_inputs = r#"{}"#; + + test_file( + "types_array_output", + public_inputs, + private_inputs, + vec!["2", "4", "1", "8"], // 2x, y, x, 2y + backend, + )?; + + Ok(()) +} diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index ff0cb9d95..2cf07783f 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -379,14 +379,21 @@ impl TypeChecker { // the output value returned by the main function is also a main_args with a special name (public_output) if let Some(typ) = &function.sig.return_type { if is_main { - if !matches!(typ.kind, TyKind::Field) { - unimplemented!(); + match typ.kind { + TyKind::Field => { + typed_fn_env.store_type( + "public_output".to_string(), + TypeInfo::new_mut(typ.kind.clone(), typ.span), + )?; + } + TyKind::Array(_, _) => { + typed_fn_env.store_type( + "public_output".to_string(), + TypeInfo::new_mut(typ.kind.clone(), typ.span), + )?; + } + _ => unimplemented!(), } - - typed_fn_env.store_type( - "public_output".to_string(), - TypeInfo::new_mut(typ.kind.clone(), typ.span), - )?; } }