diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e0bb096a..355c4de6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,6 +80,7 @@ jobs: test-rust-avx512: runs-on: 7950x3d + if: false # temporarily disabled steps: - uses: styfle/cancel-workflow-action@0.11.0 - uses: actions/checkout@v4 diff --git a/.gitignore b/.gitignore index b1ff00da..7375672a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ # generated artifact *.txt +*.json +*.witness +*.log __* target libec_go_lib.* diff --git a/Cargo.lock b/Cargo.lock index a3e937ae..a40683ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +38,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -99,9 +117,9 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ - "ark-std", + "ark-std 0.4.0", "cfg-if", "criterion", "ethnum", @@ -114,6 +132,121 @@ dependencies = [ "tynm", ] +[[package]] +name = "ark-bls12-381" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "ark-std 0.5.0", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.2", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std 0.5.0", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.79", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.2", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive", + "ark-std 0.5.0", + "arrayvec", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "ark-std" version = "0.4.0" @@ -124,6 +257,16 @@ dependencies = [ "rand", ] +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "arrayref" version = "0.3.8" @@ -163,6 +306,26 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "big-int" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31375ce97b1316b3a92644c2cbc93fa9dcfba06e4aec9a440bce23397af82fd6" +dependencies = [ + "big-int-proc", + "thiserror", +] + +[[package]] +name = "big-int-proc" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cfa06eb56d71f2bb1874b101a50c3ba29fcf3ff7dd8de274e473929459863b" +dependencies = [ + "quote", + "syn 2.0.79", +] + [[package]] name = "bindgen" version = "0.69.4" @@ -332,13 +495,15 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "config", "ethnum", + "gkr_field_config", "log", + "rand", "thiserror", "transcript", ] @@ -348,14 +513,21 @@ name = "circuit-std-rs" version = "0.1.0" dependencies = [ "arith", - "ark-std", + "ark-bls12-381", + "ark-ff", + "ark-std 0.4.0", + "big-int", "circuit", "config", "expander_compiler", "gf2", "gkr", "mersenne31", + "num-bigint", + "num-traits", "rand", + "sha2", + "tiny-keccak", ] [[package]] @@ -418,15 +590,33 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "gf2_128", + "gkr_field_config", "halo2curves", "mersenne31", "mpi", + "mpi_config", + "poly_commit", + "transcript", +] + +[[package]] +name = "config_macros" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "config", + "field_hashers", + "gkr_field_config", + "poly_commit", + "proc-macro2", + "quote", + "syn 2.0.79", "transcript", ] @@ -521,6 +711,24 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crosslayer_prototype" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "config", + "env_logger", + "ethnum", + "gkr_field_config", + "log", + "polynomials", + "rand", + "sumcheck", + "thiserror", + "transcript", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -569,11 +777,25 @@ dependencies = [ "config", "expander_compiler", "gkr", + "gkr_field_config", "libc", + "mpi_config", "rand", "transcript", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "either" version = "1.13.0" @@ -589,6 +811,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -639,18 +881,24 @@ name = "expander_compiler" version = "0.1.0" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "chrono", "circuit", "clap", "config", + "crosslayer_prototype", "ethnum", "gf2", "gkr", + "gkr_field_config", "halo2curves", "mersenne31", + "mpi_config", "rand", + "rayon", + "sha2", "tiny-keccak", + "transcript", ] [[package]] @@ -664,6 +912,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "field_hashers" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "halo2curves", + "tiny-keccak", +] + [[package]] name = "fnv" version = "1.0.7" @@ -751,10 +1009,10 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "cfg-if", "ethnum", "halo2curves", @@ -768,10 +1026,10 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "rand", ] @@ -785,25 +1043,31 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "bytes", "chrono", "circuit", "clap", "config", + "config_macros", "env_logger", "ethnum", + "field_hashers", "gf2", "gf2_128", + "gkr_field_config", "halo2curves", "log", "mersenne31", "mpi", + "mpi_config", + "poly_commit", "polynomials", "rand", + "rand_chacha", "sha2", "sumcheck", "thiserror", @@ -814,6 +1078,18 @@ dependencies = [ "warp", ] +[[package]] +name = "gkr_field_config" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "ark-std 0.4.0", + "gf2", + "gf2_128", + "mersenne31", +] + [[package]] name = "glob" version = "0.3.1" @@ -890,6 +1166,15 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", +] + [[package]] name = "headers" version = "0.3.9" @@ -1056,7 +1341,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1094,6 +1379,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1190,12 +1484,13 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "cfg-if", "ethnum", + "field_hashers", "halo2curves", "log", "rand", @@ -1273,6 +1568,15 @@ dependencies = [ "cc", ] +[[package]] +name = "mpi_config" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "mpi", +] + [[package]] name = "multer" version = "2.1.0" @@ -1475,13 +1779,32 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly_commit" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "ark-std 0.4.0", + "ethnum", + "gf2", + "gkr_field_config", + "itertools 0.13.0", + "mpi_config", + "polynomials", + "rand", + "thiserror", + "transcript", + "tree", +] + [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "criterion", "halo2curves", ] @@ -1814,13 +2137,15 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "circuit", "config", "env_logger", + "gkr_field_config", "log", + "mpi_config", "polynomials", "transcript", ] @@ -1990,13 +2315,26 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", + "field_hashers", + "mpi_config", "sha2", "tiny-keccak", ] +[[package]] +name = "tree" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "ark-std 0.4.0", + "rayon", + "sha2", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -2335,3 +2673,43 @@ checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" dependencies = [ "tap", ] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] diff --git a/Cargo.toml b/Cargo.toml index 98bcbd0b..b566927a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,10 +20,13 @@ halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-feat "bits", ] } arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } +crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev"} diff --git a/build-rust-avx512.sh b/build-rust-avx512.sh old mode 100644 new mode 100755 diff --git a/build-rust.sh b/build-rust.sh index 76564163..ae493fef 100755 --- a/build-rust.sh +++ b/build-rust.sh @@ -2,4 +2,4 @@ cd "$(dirname "$0")" cargo build --release mkdir -p ~/.cache/ExpanderCompilerCollection -cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file +cp target/release/libec_go_lib.* ~/.cache/ExpanderCompilerCollection diff --git a/circuit-std-go/logup/hint.go b/circuit-std-go/logup/hint.go new file mode 100644 index 00000000..a4d2565f --- /dev/null +++ b/circuit-std-go/logup/hint.go @@ -0,0 +1,46 @@ +package logup + +import ( + "math/big" +) + +func rangeProofHint(q *big.Int, inputs []*big.Int, outputs []*big.Int) error { + n := inputs[0].Int64() + a := new(big.Int).Set(inputs[1]) + + for i := int64(0); i < n/int64(LookupTableBits); i++ { + a, outputs[i] = new(big.Int).DivMod(a, big.NewInt(int64(1< 1 { + n >>= 1 + for i := 0; i < n; i++ { + next = append(next, cur[i*2].Add(api, &cur[i*2+1])) + } + cur = next + next = next[:0] + } + + if len(cur) != 1 { + panic("Summation code may be wrong.") + } + + return cur[0] +} + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + + // Do not introduce any randomness + if n_columns == 1 { + vec_combined := make([]frontend.Variable, n_rows) + for i := 0; i < n_rows; i++ { + vec_combined[i] = vec_2d[i][0] + } + return vec_combined + } + + if !IsPowerOf2(n_columns) { + panic("Consider support this") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} diff --git a/circuit-std-go/poseidon-m31/poseidon.go b/circuit-std-go/poseidon-m31/poseidon.go new file mode 100644 index 00000000..7fc88bca --- /dev/null +++ b/circuit-std-go/poseidon-m31/poseidon.go @@ -0,0 +1,208 @@ +package poseidonM31 + +import ( + "encoding/binary" + "math/big" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/customgates" + "github.com/PolyhedraZK/ExpanderCompilerCollection/field/m31" + "github.com/consensys/gnark/frontend" + "golang.org/x/crypto/sha3" +) + +var ( + poseidonM31x16FullRounds int + poseidonM31x16PartialRounds int + + poseidonM31x16RoundConstant [][]uint + poseidonM31x16MDS [][]uint + + POW_5_GATE_ID uint64 = 12345 + POW_5_COST_PSEUDO int = 20 +) + +func sBox(api frontend.API, f frontend.Variable) frontend.Variable { + return api.(ecgo.API).CustomGate(POW_5_GATE_ID, f) +} + +func Power5(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + a := big.NewInt(0) + a.Mul(inputs[0], inputs[0]) + a.Mul(a, a) + a.Mul(a, inputs[0]) + outputs[0] = a + return nil +} + +func init() { + poseidonM31x16FullRounds = 8 + poseidonM31x16PartialRounds = 14 + + var m31Modulus uint = uint(m31.ScalarField.Uint64()) + + // NOTE Poseidon full round parameter generation + poseidonM31x16Seed := []byte("poseidon_seed_Mersenne 31_16") + + hasher := sha3.NewLegacyKeccak256() + hasher.Write(poseidonM31x16Seed) + poseidonM31x16Seed = hasher.Sum(nil) + + poseidonM31x16RoundConstant = make([][]uint, poseidonM31x16FullRounds+poseidonM31x16PartialRounds) + for i := 0; i < int(poseidonM31x16FullRounds+poseidonM31x16PartialRounds); i++ { + poseidonM31x16RoundConstant[i] = make([]uint, 16) + + for j := 0; j < 16; j++ { + hasher.Reset() + hasher.Write(poseidonM31x16Seed) + poseidonM31x16Seed = hasher.Sum(nil) + + u32LE := binary.LittleEndian.Uint32(poseidonM31x16Seed[:4]) + poseidonM31x16RoundConstant[i][j] = uint(u32LE) % m31Modulus + } + } + + // NOTE MDS generation + poseidonM31x16MDS = make([][]uint, 16) + poseidonM31x16MDS[0] = []uint{1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3} + for i := 1; i < 16; i++ { + poseidonM31x16MDS[i] = make([]uint, 16) + for j := 0; j < 16; j++ { + poseidonM31x16MDS[i][j] = poseidonM31x16MDS[0][(i+j)%16] + } + } + + // NOTE register pow-5 gate + customgates.Register(POW_5_GATE_ID, Power5, POW_5_COST_PSEUDO) +} + +func poseidonM31x16MDSApply( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + res := make([]frontend.Variable, 16) + for i := 0; i < 16; i++ { + res[i] = 0 + } + + for i := 0; i < 16; i++ { + for j := 0; j < 16; j++ { + res[i] = api.Add(api.Mul(poseidonM31x16MDS[i][j], state[j]), res[i]) + } + } + + return res +} + +func poseidonM31x16FullRoundSBox( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + for i := 0; i < 16; i++ { + state[i] = sBox(api, state[i]) + } + + return state +} + +func poseidonM31x16PartialRoundSbox( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + state[0] = sBox(api, state[0]) + + return state +} + +func poseidonM31x16RoundConstantApply( + api frontend.API, state []frontend.Variable, round int) []frontend.Variable { + + for i := 0; i < 16; i++ { + state[i] = api.Add(state[i], poseidonM31x16RoundConstant[round][i]) + } + + return state +} + +func PoseidonM31x16Permutate( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + partialRoundEnds := poseidonM31x16FullRounds/2 + poseidonM31x16PartialRounds + allRoundEnds := poseidonM31x16FullRounds + poseidonM31x16PartialRounds + + for i := 0; i < poseidonM31x16FullRounds/2; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16FullRoundSBox(api, state) + } + + for i := poseidonM31x16FullRounds / 2; i < partialRoundEnds; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16PartialRoundSbox(api, state) + } + + for i := partialRoundEnds; i < allRoundEnds; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16FullRoundSBox(api, state) + } + + return state +} + +type PoseidonM31x16Permutation struct { + State [16]frontend.Variable + Digest [16]frontend.Variable +} + +func (p *PoseidonM31x16Permutation) Define(api frontend.API) error { + + digest := poseidonM31x16FullRoundSBox(api, p.State[:]) + + for i := 0; i < 16; i++ { + api.AssertIsEqual(p.Digest[i], digest[i]) + } + + return nil +} + +func PoseidonM31x16HashToState( + api frontend.API, fs []frontend.Variable) ([]frontend.Variable, uint) { + + poseidonM31x16Rate := 8 + poseidonM31x16Capacity := 16 - poseidonM31x16Rate + numChunks := (len(fs) + poseidonM31x16Rate - 1) / poseidonM31x16Rate + + absorbBuffer := make([]frontend.Variable, numChunks*poseidonM31x16Rate) + copy(absorbBuffer, fs) + for i := len(fs); i < len(absorbBuffer); i++ { + absorbBuffer[i] = 0 + } + + res := make([]frontend.Variable, 16) + for i := 0; i < 16; i++ { + res[i] = 0 + } + + for i := 0; i < numChunks; i++ { + for j := poseidonM31x16Capacity; j < 16; j++ { + res[j] = api.Add(res[j], absorbBuffer[i*poseidonM31x16Rate+j-poseidonM31x16Capacity]) + } + res = PoseidonM31x16Permutate(api, res) + } + + return res, uint(numChunks) +} + +type PoseidonM31x16Sponge struct { + ToBeHashed []frontend.Variable + Digest [16]frontend.Variable +} + +func (p *PoseidonM31x16Sponge) Define(api frontend.API) error { + digest, _ := PoseidonM31x16HashToState(api, p.ToBeHashed) + + for i := 0; i < 16; i++ { + api.AssertIsEqual(digest[i], p.Digest[i]) + } + + return nil +} diff --git a/circuit-std-go/poseidon-m31/poseidon_test.go b/circuit-std-go/poseidon-m31/poseidon_test.go new file mode 100644 index 00000000..da5027e2 --- /dev/null +++ b/circuit-std-go/poseidon-m31/poseidon_test.go @@ -0,0 +1,82 @@ +package poseidonM31 + +import ( + "testing" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" + "github.com/PolyhedraZK/ExpanderCompilerCollection/field/m31" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/require" +) + +func TestPoseidonM31x16Params(t *testing.T) { + require.Equal(t, + uint(80596940), + poseidonM31x16RoundConstant[0][0], + "poseidon round constant m31x16 0.0 not matching ggs", + ) +} + +func TestPoseidonM31x16HashToState(t *testing.T) { + + testcases := []struct { + InputLen uint + Assignment PoseidonM31x16Sponge + }{ + { + InputLen: 8, + Assignment: PoseidonM31x16Sponge{ + ToBeHashed: []frontend.Variable{ + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + }, + Digest: [16]frontend.Variable{ + 1021105124, 1342990709, 1593716396, 2100280498, + 330652568, 1371365483, 586650367, 345482939, + 849034538, 175601510, 1454280121, 1362077584, + 528171622, 187534772, 436020341, 1441052621, + }, + }, + }, + { + InputLen: 16, + Assignment: PoseidonM31x16Sponge{ + ToBeHashed: []frontend.Variable{ + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + }, + Digest: [16]frontend.Variable{ + 1510043913, 1840611937, 45881205, 1134797377, + 803058407, 1772167459, 846553905, 2143336151, + 300871060, 545838827, 1603101164, 396293243, + 502075988, 2067011878, 402134378, 535675968, + }, + }, + }, + } + + for _, testcase := range testcases { + circuit := PoseidonM31x16Sponge{ + ToBeHashed: make([]frontend.Variable, testcase.InputLen), + } + circuitCompileResult, err := ecgo.Compile( + m31.ScalarField, + &circuit, + ) + require.NoError(t, err, "ggs compile circuit error") + layeredCircuit := circuitCompileResult.GetLayeredCircuit() + + inputSolver := circuitCompileResult.GetInputSolver() + witness, err := inputSolver.SolveInput(&testcase.Assignment, 0) + require.NoError(t, err, "ggs solving witness error") + + require.True( + t, + test.CheckCircuit(layeredCircuit, witness), + "ggs check circuit error", + ) + } +} diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index 9fbf43af..e98717d7 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -15,3 +15,10 @@ gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true +sha2 = "0.10.8" +big-int = "7.0.0" +num-bigint = "0.4.6" +num-traits = "0.2.19" +ark-bls12-381 = "0.5.0" +ark-ff = "0.5.0" +tiny-keccak = { version = "2.0.2", features = [ "sha3", "keccak" ] } diff --git a/circuit-std-rs/src/gnark/element.rs b/circuit-std-rs/src/gnark/element.rs new file mode 100644 index 00000000..e4a863d5 --- /dev/null +++ b/circuit-std-rs/src/gnark/element.rs @@ -0,0 +1,142 @@ +use crate::gnark::emparam::FieldParams; +use crate::gnark::limbs::*; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::ToPrimitive; +use std::any::Any; +use std::cmp::Ordering; +#[derive(Default, Clone)] +pub struct Element { + pub limbs: Vec, + pub overflow: u32, + pub internal: bool, + pub mod_reduced: bool, + pub is_evaluated: bool, + pub evaluation: Variable, + pub _marker: std::marker::PhantomData, +} + +impl Element { + pub fn new( + limbs: Vec, + overflow: u32, + internal: bool, + mod_reduced: bool, + is_evaluated: bool, + evaluation: Variable, + ) -> Self { + Self { + limbs, + overflow, + internal, + mod_reduced, + is_evaluated, + evaluation, + _marker: std::marker::PhantomData, + } + } + pub fn my_default() -> Self { + Self { + limbs: Vec::new(), + overflow: 0, + internal: false, + mod_reduced: false, + is_evaluated: false, + evaluation: Variable::default(), + _marker: std::marker::PhantomData, + } + } + pub fn my_clone(&self) -> Self { + Self { + limbs: self.limbs.clone(), + overflow: self.overflow, + internal: self.internal, + mod_reduced: self.mod_reduced, + is_evaluated: self.is_evaluated, + evaluation: self.evaluation, + _marker: std::marker::PhantomData, + } + } + pub fn is_empty(&self) -> bool { + self.limbs.is_empty() + } +} +pub fn value_of, T: FieldParams>( + api: &mut B, + constant: Box, +) -> Element { + let r: Element = new_const_element::(api, constant); + r +} +pub fn new_const_element, T: FieldParams>( + api: &mut B, + v: Box, +) -> Element { + let fp = T::modulus(); + // convert to big.Int + let mut b_value = from_interface(v); + // mod reduce + if fp.cmp(&b_value) != Ordering::Equal { + b_value %= fp; + } + + // decompose into limbs + let mut blimbs = vec![BigInt::default(); T::nb_limbs() as usize]; + let mut limbs = vec![Variable::default(); blimbs.len()]; + if let Err(err) = decompose(&b_value, T::bits_per_limb(), &mut blimbs) { + panic!("decompose value: {}", err); + } + // assign limb values + for i in 0..limbs.len() { + limbs[i] = api.constant(blimbs[i].to_u64().unwrap() as u32); + } + Element::new(limbs, 0, true, false, false, Variable::default()) +} +pub fn new_internal_element(limbs: Vec, overflow: u32) -> Element { + Element::new(limbs, overflow, true, false, false, Variable::default()) +} +pub fn copy(e: &Element) -> Element { + let mut r = Element::new(Vec::new(), 0, false, false, false, Variable::default()); + r.limbs = e.limbs.clone(); + r.overflow = e.overflow; + r.internal = e.internal; + r.mod_reduced = e.mod_reduced; + r +} +pub fn from_interface(input: Box) -> BigInt { + let r; + + if let Some(v) = input.downcast_ref::() { + r = v.clone(); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v as u64); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v as i64); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::parse_bytes(v.as_bytes(), 10).unwrap_or_else(|| { + panic!("unable to set BigInt from string: {}", v); + }); + } else if let Some(v) = input.downcast_ref::>() { + r = BigInt::from_bytes_be(num_bigint::Sign::Plus, v); + } else { + panic!("value to BigInt not supported"); + } + + r +} diff --git a/circuit-std-rs/src/gnark/emparam.rs b/circuit-std-rs/src/gnark/emparam.rs new file mode 100644 index 00000000..49905eb1 --- /dev/null +++ b/circuit-std-rs/src/gnark/emparam.rs @@ -0,0 +1,72 @@ +use num_bigint::BigInt; + +#[derive(Default, Clone, Copy)] +pub struct Bls12381Fp {} +impl Bls12381Fp { + pub fn nb_limbs() -> u32 { + 48 + } + pub fn bits_per_limb() -> u32 { + 8 + } + pub fn is_prime() -> bool { + true + } + pub fn modulus() -> BigInt { + let hex_str = "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"; + BigInt::parse_bytes(hex_str.as_bytes(), 16).unwrap() + } +} +#[derive(Default, Clone)] +pub struct Bls12381Fr {} +impl Bls12381Fr { + pub fn nb_limbs() -> u32 { + 32 + } + pub fn bits_per_limb() -> u32 { + 8 + } + pub fn is_prime() -> bool { + true + } + pub fn modulus() -> BigInt { + let hex_str = "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001"; + BigInt::parse_bytes(hex_str.as_bytes(), 16).unwrap() + } +} +pub trait FieldParams { + fn nb_limbs() -> u32; + fn bits_per_limb() -> u32; + fn is_prime() -> bool; + fn modulus() -> BigInt; +} + +impl FieldParams for Bls12381Fr { + fn nb_limbs() -> u32 { + Bls12381Fr::nb_limbs() + } + fn bits_per_limb() -> u32 { + Bls12381Fr::bits_per_limb() + } + fn is_prime() -> bool { + Bls12381Fr::is_prime() + } + fn modulus() -> BigInt { + Bls12381Fr::modulus() + } +} + +impl FieldParams for Bls12381Fp { + fn nb_limbs() -> u32 { + Bls12381Fp::nb_limbs() + } + fn bits_per_limb() -> u32 { + Bls12381Fp::bits_per_limb() + } + fn is_prime() -> bool { + Bls12381Fp::is_prime() + } + fn modulus() -> BigInt { + Bls12381Fp::modulus() + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs new file mode 100644 index 00000000..732cd790 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs @@ -0,0 +1,601 @@ +use expander_compiler::frontend::{Config, RootAPI, Variable}; + +use super::e2::*; +use super::e6::*; +#[derive(Default, Clone)] +pub struct GE12 { + pub c0: GE6, + pub c1: GE6, +} +impl GE12 { + pub fn my_clone(&self) -> Self { + GE12 { + c0: self.c0.my_clone(), + c1: self.c1.my_clone(), + } + } +} +pub struct Ext12 { + pub ext6: Ext6, +} + +impl Ext12 { + pub fn new>(api: &mut B) -> Self { + Self { + ext6: Ext6::new(api), + } + } + pub fn zero(&mut self) -> GE12 { + let zero = self.ext6.ext2.curve_f.zero_const.clone(); + GE12 { + c0: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + } + } + pub fn one(&mut self) -> GE12 { + let one = self.ext6.ext2.curve_f.one_const.clone(); + let zero = self.ext6.ext2.curve_f.zero_const.clone(); + GE12 { + c0: GE6 { + b0: GE2 { + a0: one.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE12) -> Variable { + let c0 = self.ext6.is_zero(native, &z.c0); + let c1 = self.ext6.is_zero(native, &z.c1); + native.and(c0, c1) + } + pub fn add>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let z0 = self.ext6.add(native, &x.c0, &y.c0); + let z1 = self.ext6.add(native, &x.c1, &y.c1); + GE12 { c0: z0, c1: z1 } + } + pub fn sub>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let z0 = self.ext6.sub(native, &x.c0, &y.c0); + let z1 = self.ext6.sub(native, &x.c1, &y.c1); + GE12 { c0: z0, c1: z1 } + } + pub fn conjugate>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z1 = self.ext6.neg(native, &x.c1); + GE12 { + c0: x.c0.my_clone(), + c1: z1, + } + } + pub fn mul>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let a = self.ext6.add(native, &x.c0, &x.c1); + let b = self.ext6.add(native, &y.c0, &y.c1); + let a = self.ext6.mul(native, &a, &b); + let b = self.ext6.mul(native, &x.c0, &y.c0); + let c = self.ext6.mul(native, &x.c1, &y.c1); + let d = self.ext6.add(native, &c, &b); + let z1 = self.ext6.sub(native, &a, &d); + let z0 = self.ext6.mul_by_non_residue(native, &c); + let z0 = self.ext6.add(native, &z0, &b); + GE12 { c0: z0, c1: z1 } + } + pub fn square>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let c0 = self.ext6.sub(native, &x.c0, &x.c1); + let c3 = self.ext6.mul_by_non_residue(native, &x.c1); + let c3 = self.ext6.sub(native, &x.c0, &c3); + let c2 = self.ext6.mul(native, &x.c0, &x.c1); + let c0 = self.ext6.mul(native, &c0, &c3); + let c0 = self.ext6.add(native, &c0, &c2); + let z1 = self.ext6.double(native, &c2); + let c2 = self.ext6.mul_by_non_residue(native, &c2); + let z0 = self.ext6.add(native, &c0, &c2); + GE12 { c0: z0, c1: z1 } + } + + pub fn cyclotomic_square>( + &mut self, + native: &mut B, + x: &GE12, + ) -> GE12 { + let t0 = self.ext6.ext2.square(native, &x.c1.b1); + let t1 = self.ext6.ext2.square(native, &x.c0.b0); + let mut t6 = self.ext6.ext2.add(native, &x.c1.b1, &x.c0.b0); + t6 = self.ext6.ext2.square(native, &t6); + t6 = self.ext6.ext2.sub(native, &t6, &t0); + t6 = self.ext6.ext2.sub(native, &t6, &t1); + let t2 = self.ext6.ext2.square(native, &x.c0.b2); + let t3 = self.ext6.ext2.square(native, &x.c1.b0); + let mut t7 = self.ext6.ext2.add(native, &x.c0.b2, &x.c1.b0); + t7 = self.ext6.ext2.square(native, &t7); + t7 = self.ext6.ext2.sub(native, &t7, &t2); + t7 = self.ext6.ext2.sub(native, &t7, &t3); + let t4 = self.ext6.ext2.square(native, &x.c1.b2); + let t5 = self.ext6.ext2.square(native, &x.c0.b1); + let mut t8 = self.ext6.ext2.add(native, &x.c1.b2, &x.c0.b1); + t8 = self.ext6.ext2.square(native, &t8); + t8 = self.ext6.ext2.sub(native, &t8, &t4); + t8 = self.ext6.ext2.sub(native, &t8, &t5); + t8 = self.ext6.ext2.mul_by_non_residue(native, &t8); + let t0 = self.ext6.ext2.mul_by_non_residue(native, &t0); + let t0 = self.ext6.ext2.add(native, &t0, &t1); + let t2 = self.ext6.ext2.mul_by_non_residue(native, &t2); + let t2 = self.ext6.ext2.add(native, &t2, &t3); + let t4 = self.ext6.ext2.mul_by_non_residue(native, &t4); + let t4 = self.ext6.ext2.add(native, &t4, &t5); + let z00 = self.ext6.ext2.sub(native, &t0, &x.c0.b0); + let z00 = self.ext6.ext2.double(native, &z00); + let z00 = self.ext6.ext2.add(native, &z00, &t0); + let z01 = self.ext6.ext2.sub(native, &t2, &x.c0.b1); + let z01 = self.ext6.ext2.double(native, &z01); + let z01 = self.ext6.ext2.add(native, &z01, &t2); + let z02 = self.ext6.ext2.sub(native, &t4, &x.c0.b2); + let z02 = self.ext6.ext2.double(native, &z02); + let z02 = self.ext6.ext2.add(native, &z02, &t4); + let z10 = self.ext6.ext2.add(native, &t8, &x.c1.b0); + let z10 = self.ext6.ext2.double(native, &z10); + let z10 = self.ext6.ext2.add(native, &z10, &t8); + let z11 = self.ext6.ext2.add(native, &t6, &x.c1.b1); + let z11 = self.ext6.ext2.double(native, &z11); + let z11 = self.ext6.ext2.add(native, &z11, &t6); + let z12 = self.ext6.ext2.add(native, &t7, &x.c1.b2); + let z12 = self.ext6.ext2.double(native, &z12); + let z12 = self.ext6.ext2.add(native, &z12, &t7); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE12, y: &GE12) { + self.ext6.assert_isequal(native, &x.c0, &y.c0); + self.ext6.assert_isequal(native, &x.c1, &y.c1); + } + pub fn div>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + y.c0.b0.a0.clone(), + y.c0.b0.a1.clone(), + y.c0.b1.a0.clone(), + y.c0.b1.a1.clone(), + y.c0.b2.a0.clone(), + y.c0.b2.a1.clone(), + y.c1.b0.a0.clone(), + y.c1.b0.a1.clone(), + y.c1.b1.a0.clone(), + y.c1.b1.a1.clone(), + y.c1.b2.a0.clone(), + y.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.dive12hint", 24, inputs); + let div = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, x, &_x); + div + } + pub fn inverse>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.inversee12hint", 12, inputs); + let inv = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let one = self.one(); + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn copy>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.copye12hint", 12, inputs); + let res = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + self.assert_isequal(native, x, &res); + res + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE12, + z0: &GE12, + ) -> GE12 { + let c0 = self.ext6.select(native, selector, &z1.c0, &z0.c0); + let c1 = self.ext6.select(native, selector, &z1.c1, &z0.c1); + GE12 { c0, c1 } + } + + /////// pairing /////// + pub fn mul_by_014>( + &mut self, + native: &mut B, + z: &GE12, + c0: &GE2, + c1: &GE2, + ) -> GE12 { + let a = self.ext6.mul_by_01(native, &z.c0, c0, c1); + let b = GE6 { + b0: self.ext6.ext2.mul_by_non_residue(native, &z.c1.b2), + b1: z.c1.b0.clone(), + b2: z.c1.b1.clone(), + }; + let one = self.ext6.ext2.one(); + let d = self.ext6.ext2.add(native, c1, &one); + let zc1 = self.ext6.add(native, &z.c1, &z.c0); + let zc1 = self.ext6.mul_by_01(native, &zc1, c0, &d); + let tmp = self.ext6.add(native, &b, &a); + let zc1 = self.ext6.sub(native, &zc1, &tmp); + let zc0 = self.ext6.mul_by_non_residue(native, &b); + let zc0 = self.ext6.add(native, &zc0, &a); + GE12 { c0: zc0, c1: zc1 } + } + pub fn mul_014_by_014>( + &mut self, + native: &mut B, + d0: &GE2, + d1: &GE2, + c0: &GE2, + c1: &GE2, + ) -> [GE2; 5] { + let x0 = self.ext6.ext2.mul(native, c0, d0); + let x1 = self.ext6.ext2.mul(native, c1, d1); + let x04 = self.ext6.ext2.add(native, c0, d0); + let tmp = self.ext6.ext2.add(native, c0, c1); + let x01 = self.ext6.ext2.add(native, d0, d1); + let x01 = self.ext6.ext2.mul(native, &x01, &tmp); + let tmp = self.ext6.ext2.add(native, &x1, &x0); + let x01 = self.ext6.ext2.sub(native, &x01, &tmp); + let x14 = self.ext6.ext2.add(native, c1, d1); + let z_c0_b0 = self.ext6.ext2.non_residue(native); + let z_c0_b0 = self.ext6.ext2.add(native, &z_c0_b0, &x0); + [z_c0_b0, x01, x1, x04, x14] + } + pub fn expt>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z = self.cyclotomic_square(native, x); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 2); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 3); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 9); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 32); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 15); + self.cyclotomic_square(native, &z) + } + pub fn n_square_gs>( + &mut self, + native: &mut B, + z: &GE12, + n: usize, + ) -> GE12 { + let mut new_z = z.my_clone(); + for _ in 0..n { + new_z = self.cyclotomic_square(native, &new_z); + } + new_z + } + pub fn n_square_gs_with_hint>( + &mut self, + native: &mut B, + z: &GE12, + n: usize, + ) -> GE12 { + let mut copy_z = self.copy(native, z); + for _ in 0..n - 1 { + let z = self.cyclotomic_square(native, ©_z); + copy_z = self.copy(native, &z); + } + self.cyclotomic_square(native, ©_z) + } + pub fn assert_final_exponentiation_is_one>( + &mut self, + native: &mut B, + x: &GE12, + ) { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.finalexphint", 18, inputs); + let residue_witness = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let scaling_factor = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[12].clone(), + a1: output[13].clone(), + }, + b1: GE2 { + a0: output[14].clone(), + a1: output[15].clone(), + }, + b2: GE2 { + a0: output[16].clone(), + a1: output[17].clone(), + }, + }, + c1: self.zero().c1, + }; + let t0 = self.frobenius(native, &residue_witness); + let t1 = self.expt(native, &residue_witness); + let t0 = self.mul(native, &t0, &t1); + let t1 = self.mul(native, x, &scaling_factor); + self.assert_isequal(native, &t0, &t1); + } + + pub fn frobenius>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z00 = self.ext6.ext2.conjugate(native, &x.c0.b0); + let z01 = self.ext6.ext2.conjugate(native, &x.c0.b1); + let z02 = self.ext6.ext2.conjugate(native, &x.c0.b2); + let z10 = self.ext6.ext2.conjugate(native, &x.c1.b0); + let z11 = self.ext6.ext2.conjugate(native, &x.c1.b1); + let z12 = self.ext6.ext2.conjugate(native, &x.c1.b2); + + let z01 = self.ext6.ext2.mul_by_non_residue1_power2(native, &z01); + let z02 = self.ext6.ext2.mul_by_non_residue1_power4(native, &z02); + let z10 = self.ext6.ext2.mul_by_non_residue1_power1(native, &z10); + let z11 = self.ext6.ext2.mul_by_non_residue1_power3(native, &z11); + let z12 = self.ext6.ext2.mul_by_non_residue1_power5(native, &z12); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } + pub fn frobenius_square>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z00 = x.c0.b0.clone(); + let z01 = self.ext6.ext2.mul_by_non_residue2_power2(native, &x.c0.b1); + let z02 = self.ext6.ext2.mul_by_non_residue2_power4(native, &x.c0.b2); + let z10 = self.ext6.ext2.mul_by_non_residue2_power1(native, &x.c1.b0); + let z11 = self.ext6.ext2.mul_by_non_residue2_power3(native, &x.c1.b1); + let z12 = self.ext6.ext2.mul_by_non_residue2_power5(native, &x.c1.b2); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs new file mode 100644 index 00000000..a57498db --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs @@ -0,0 +1,406 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::*; +use crate::gnark::field::GField; +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use num_bigint::BigInt; +use std::collections::HashMap; + +pub type CurveF = GField; +#[derive(Default, Clone)] +pub struct GE2 { + pub a0: Element, + pub a1: Element, +} +impl GE2 { + pub fn my_clone(&self) -> Self { + GE2 { + a0: self.a0.my_clone(), + a1: self.a1.my_clone(), + } + } + pub fn from_vars(x: Vec, y: Vec) -> Self { + GE2 { + a0: Element::new(x, 0, false, false, false, Variable::default()), + a1: Element::new(y, 0, false, false, false, Variable::default()), + } + } +} + +pub struct Ext2 { + pub curve_f: CurveF, + non_residues: HashMap>, +} + +impl Ext2 { + pub fn new>(api: &mut B) -> Self { + let mut _non_residues: HashMap> = HashMap::new(); + let mut pwrs: HashMap> = HashMap::new(); + let a1_1_0 = value_of::(api, Box::new("3850754370037169011952147076051364057158807420970682438676050522613628423219637725072182697113062777891589506424760".to_string())); + let a1_1_1 = value_of::(api, Box::new("151655185184498381465642749684540099398075398968325446656007613510403227271200139370504932015952886146304766135027".to_string())); + let a1_2_0 = value_of::(api, Box::new("0".to_string())); + let a1_2_1 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a1_3_0 = value_of::(api, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let a1_3_1 = value_of::(api, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let a1_4_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a1_4_1 = value_of::(api, Box::new("0".to_string())); + let a1_5_0 = value_of::(api, Box::new("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230".to_string())); + let a1_5_1 = value_of::(api, Box::new("3125332594171059424908108096204648978570118281977575435832422631601824034463382777937621250592425535493320683825557".to_string())); + let a2_1_0 = value_of::(api, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351".to_string())); + let a2_2_0 = value_of::(api, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350".to_string())); + let a2_3_0 = value_of::(api, Box::new("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786".to_string())); + let a2_4_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a2_5_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + pwrs.insert(1, HashMap::new()); + pwrs.get_mut(&1).unwrap().insert( + 1, + GE2 { + a0: a1_1_0, + a1: a1_1_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 2, + GE2 { + a0: a1_2_0, + a1: a1_2_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 3, + GE2 { + a0: a1_3_0, + a1: a1_3_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 4, + GE2 { + a0: a1_4_0, + a1: a1_4_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 5, + GE2 { + a0: a1_5_0, + a1: a1_5_1, + }, + ); + pwrs.insert(2, HashMap::new()); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 1, + GE2 { + a0: a2_1_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 2, + GE2 { + a0: a2_2_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 3, + GE2 { + a0: a2_3_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 4, + GE2 { + a0: a2_4_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 5, + GE2 { + a0: a2_5_0, + a1: a_zero, + }, + ); + let fp = CurveF::new(api, Bls12381Fp {}); + Ext2 { + curve_f: fp, + non_residues: pwrs, + } + } + pub fn one(&mut self) -> GE2 { + let z0 = self.curve_f.one_const.my_clone(); + let z1 = self.curve_f.zero_const.my_clone(); + GE2 { a0: z0, a1: z1 } + } + pub fn zero(&mut self) -> GE2 { + let z0 = self.curve_f.zero_const.my_clone(); + let z1 = self.curve_f.zero_const.my_clone(); + GE2 { a0: z0, a1: z1 } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE2) -> Variable { + let a0 = self.curve_f.is_zero(native, &z.a0); + let a1 = self.curve_f.is_zero(native, &z.a1); + native.and(a0, a1) + } + pub fn add>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let z0 = self.curve_f.add(native, &x.a0, &y.a0); + let z1 = self.curve_f.add(native, &x.a1, &y.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn sub>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let z0 = self.curve_f.sub(native, &x.a0, &y.a0); + let z1 = self.curve_f.sub(native, &x.a1, &y.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn double>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let two = BigInt::from(2); + let z0 = self.curve_f.mul_const(native, &x.a0, two.clone()); + let z1 = self.curve_f.mul_const(native, &x.a1, two.clone()); + GE2 { a0: z0, a1: z1 } + } + pub fn neg>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let z0 = self.curve_f.neg(native, &x.a0); + let z1 = self.curve_f.neg(native, &x.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn mul>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let v0 = self.curve_f.mul(native, &x.a0, &y.a0); + let v1 = self.curve_f.mul(native, &x.a1, &y.a1); + let b0 = self.curve_f.sub(native, &v0, &v1); + let mut b1 = self.curve_f.add(native, &x.a0, &x.a1); + let mut tmp = self.curve_f.add(native, &y.a0, &y.a1); + b1 = self.curve_f.mul(native, &b1, &tmp); + tmp = self.curve_f.add(native, &v0, &v1); + b1 = self.curve_f.sub(native, &b1, &tmp); + GE2 { a0: b0, a1: b1 } + } + pub fn mul_by_element>( + &mut self, + native: &mut B, + x: &GE2, + y: &Element, + ) -> GE2 { + let v0 = self.curve_f.mul(native, &x.a0, y); + let v1 = self.curve_f.mul(native, &x.a1, y); + GE2 { a0: v0, a1: v1 } + } + pub fn mul_by_const_element>( + &mut self, + native: &mut B, + x: &GE2, + y: &BigInt, + ) -> GE2 { + let z0 = self.curve_f.mul_const(native, &x.a0, y.clone()); + let z1 = self.curve_f.mul_const(native, &x.a1, y.clone()); + GE2 { a0: z0, a1: z1 } + } + pub fn mul_by_non_residue>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let a = self.curve_f.sub(native, &x.a0, &x.a1); + let b = self.curve_f.add(native, &x.a0, &x.a1); + GE2 { a0: a, a1: b } + } + pub fn square>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let a = self.curve_f.add(native, &x.a0, &x.a1); + let b = self.curve_f.sub(native, &x.a0, &x.a1); + let a = self.curve_f.mul(native, &a, &b); + let b = self.curve_f.mul(native, &x.a0, &x.a1); + let b = self.curve_f.mul_const(native, &b, BigInt::from(2)); + GE2 { a0: a, a1: b } + } + pub fn div>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let inputs = vec![ + x.a0.my_clone(), + x.a1.my_clone(), + y.a0.my_clone(), + y.a1.my_clone(), + ]; + let output = self.curve_f.new_hint(native, "myhint.dive2hint", 2, inputs); + let div = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, x, &_x); + div + } + pub fn inverse_div>(&mut self, native: &mut B, x: &GE2) -> GE2 { + self.div( + native, + &GE2 { + a0: self.curve_f.one_const.my_clone(), + a1: self.curve_f.zero_const.my_clone(), + }, + x, + ) + } + pub fn inverse>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let inputs = vec![x.a0.my_clone(), x.a1.my_clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.inversee2hint", 2, inputs); + let inv = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + let one = GE2 { + a0: self.curve_f.one_const.my_clone(), + a1: self.curve_f.zero_const.my_clone(), + }; + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE2, y: &GE2) { + self.curve_f.assert_isequal(native, &x.a0, &y.a0); + self.curve_f.assert_isequal(native, &x.a1, &y.a1); + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE2, + z0: &GE2, + ) -> GE2 { + let a0 = self.curve_f.select(native, selector, &z1.a0, &z0.a0); + let a1 = self.curve_f.select(native, selector, &z1.a1, &z0.a1); + GE2 { a0, a1 } + } + pub fn conjugate>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let z0 = x.a0.my_clone(); + let z1 = self.curve_f.neg(native, &x.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn mul_by_non_residue_generic>( + &mut self, + native: &mut B, + x: &GE2, + power: u32, + coef: u32, + ) -> GE2 { + let y = self + .non_residues + .get(&power) + .unwrap() + .get(&coef) + .unwrap() + .my_clone(); + self.mul(native, x, &y) + } + pub fn mul_by_non_residue1_power1>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 1) + } + pub fn mul_by_non_residue1_power2>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a = self.curve_f.mul(native, &x.a1, &element); + let a = self.curve_f.neg(native, &a); + let b = self.curve_f.mul(native, &x.a0, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue1_power3>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 3) + } + pub fn mul_by_non_residue1_power4>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue1_power5>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 5) + } + pub fn mul_by_non_residue2_power1>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power2>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power3>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power4>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power5>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn non_residue>(&mut self, _native: &mut B) -> GE2 { + let one = self.curve_f.one_const.my_clone(); + GE2 { + a0: one.my_clone(), + a1: one.my_clone(), + } + } + pub fn copy>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let inputs = vec![x.a0.my_clone(), x.a1.my_clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.copye2hint", 2, inputs); + let res = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + self.assert_isequal(native, x, &res); + res + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs new file mode 100644 index 00000000..e2f3972f --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs @@ -0,0 +1,377 @@ +use crate::gnark::{element::Element, emparam::Bls12381Fp}; +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use num_bigint::BigInt; + +use super::e2::*; +#[derive(Default, Clone)] +pub struct GE6 { + pub b0: GE2, + pub b1: GE2, + pub b2: GE2, +} +impl GE6 { + pub fn my_clone(&self) -> Self { + GE6 { + b0: self.b0.my_clone(), + b1: self.b1.my_clone(), + b2: self.b2.my_clone(), + } + } +} +pub struct Ext6 { + pub ext2: Ext2, +} + +impl Ext6 { + pub fn new>(api: &mut B) -> Self { + Self { + ext2: Ext2::new(api), + } + } + pub fn one(&mut self) -> GE6 { + let b0 = self.ext2.one(); + let b1 = self.ext2.zero(); + let b2 = self.ext2.zero(); + GE6 { b0, b1, b2 } + } + pub fn zero>(&mut self) -> GE6 { + let b0 = self.ext2.zero(); + let b1 = self.ext2.zero(); + let b2 = self.ext2.zero(); + GE6 { b0, b1, b2 } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE6) -> Variable { + let b0 = self.ext2.is_zero(native, &z.b0.my_clone()); + let b1 = self.ext2.is_zero(native, &z.b1.my_clone()); + let b2 = self.ext2.is_zero(native, &z.b2.my_clone()); + let tmp = native.and(b0, b1); + native.and(tmp, b2) + } + pub fn add>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let z0 = self.ext2.add(native, &x.b0.my_clone(), &y.b0.my_clone()); + let z1 = self.ext2.add(native, &x.b1.my_clone(), &y.b1.my_clone()); + let z2 = self.ext2.add(native, &x.b2.my_clone(), &y.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn neg>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.neg(native, &x.b0.my_clone()); + let z1 = self.ext2.neg(native, &x.b1.my_clone()); + let z2 = self.ext2.neg(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn sub>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let z0 = self.ext2.sub(native, &x.b0.my_clone(), &y.b0.my_clone()); + let z1 = self.ext2.sub(native, &x.b1.my_clone(), &y.b1.my_clone()); + let z2 = self.ext2.sub(native, &x.b2.my_clone(), &y.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn double>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.double(native, &x.b0.my_clone()); + let z1 = self.ext2.double(native, &x.b1.my_clone()); + let z2 = self.ext2.double(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn square>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let c4 = self.ext2.mul(native, &x.b0.my_clone(), &x.b1.my_clone()); + let c4 = self.ext2.double(native, &c4); + let c5 = self.ext2.square(native, &x.b2.my_clone()); + let c1 = self.ext2.mul_by_non_residue(native, &c5); + let c1 = self.ext2.add(native, &c1, &c4); + let c2 = self.ext2.sub(native, &c4, &c5); + let c3 = self.ext2.square(native, &x.b0.my_clone()); + let c4 = self.ext2.sub(native, &x.b0.my_clone(), &x.b1.my_clone()); + let c4 = self.ext2.add(native, &c4, &x.b2.my_clone()); + let c5 = self.ext2.mul(native, &x.b1.my_clone(), &x.b2.my_clone()); + let c5 = self.ext2.double(native, &c5); + let c4 = self.ext2.square(native, &c4); + let c0 = self.ext2.mul_by_non_residue(native, &c5); + let c0 = self.ext2.add(native, &c0, &c3); + let z2 = self.ext2.add(native, &c2, &c4); + let z2 = self.ext2.add(native, &z2, &c5); + let z2 = self.ext2.sub(native, &z2, &c3); + let z0 = c0; + let z1 = c1; + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn mul_by_e2>(&mut self, native: &mut B, x: &GE6, y: &GE2) -> GE6 { + let z0 = self.ext2.mul(native, &x.b0.my_clone(), y); + let z1 = self.ext2.mul(native, &x.b1.my_clone(), y); + let z2 = self.ext2.mul(native, &x.b2.my_clone(), y); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn mul_by_12>( + &mut self, + native: &mut B, + x: &GE6, + b1: &GE2, + b2: &GE2, + ) -> GE6 { + let t1 = self.ext2.mul(native, &x.b1.my_clone(), b1); + let t2 = self.ext2.mul(native, &x.b2.my_clone(), b2); + let mut c0 = self.ext2.add(native, &x.b1.my_clone(), &x.b2.my_clone()); + let mut tmp = self.ext2.add(native, b1, b2); + c0 = self.ext2.mul(native, &c0, &tmp); + tmp = self.ext2.add(native, &t1, &t2); + c0 = self.ext2.sub(native, &c0, &tmp); + c0 = self.ext2.mul_by_non_residue(native, &c0); + let mut c1 = self.ext2.add(native, &x.b0.my_clone(), &x.b1.my_clone()); + c1 = self.ext2.mul(native, &c1, b1); + c1 = self.ext2.sub(native, &c1, &t1); + tmp = self.ext2.mul_by_non_residue(native, &t2); + c1 = self.ext2.add(native, &c1, &tmp); + tmp = self.ext2.add(native, &x.b0.my_clone(), &x.b2.my_clone()); + let mut c2 = self.ext2.mul(native, b2, &tmp); + c2 = self.ext2.sub(native, &c2, &t2); + c2 = self.ext2.add(native, &c2, &t1); + GE6 { + b0: c0, + b1: c1, + b2: c2, + } + } + pub fn mul_by_0>(&mut self, native: &mut B, z: &GE6, c0: &GE2) -> GE6 { + let a = self.ext2.mul(native, &z.b0.my_clone(), c0); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b2.my_clone()); + let mut t2 = self.ext2.mul(native, c0, &tmp); + t2 = self.ext2.sub(native, &t2, &a); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b1.my_clone()); + let mut t1 = self.ext2.mul(native, c0, &tmp); + t1 = self.ext2.sub(native, &t1, &a); + GE6 { + b0: a, + b1: t1, + b2: t2, + } + } + pub fn mul_by_01>( + &mut self, + native: &mut B, + z: &GE6, + c0: &GE2, + c1: &GE2, + ) -> GE6 { + let a = self.ext2.mul(native, &z.b0, c0); + let b = self.ext2.mul(native, &z.b1, c1); + let tmp = self.ext2.add(native, &z.b1.my_clone(), &z.b2.my_clone()); + let mut t0 = self.ext2.mul(native, c1, &tmp); + + t0 = self.ext2.sub(native, &t0, &b); + t0 = self.ext2.mul_by_non_residue(native, &t0); + t0 = self.ext2.add(native, &t0, &a); + let mut t2 = self.ext2.mul(native, &z.b2.my_clone(), c0); + t2 = self.ext2.add(native, &t2, &b); + let mut t1 = self.ext2.add(native, c0, c1); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b1.my_clone()); + t1 = self.ext2.mul(native, &t1, &tmp); + let tmp = self.ext2.add(native, &a, &b); + t1 = self.ext2.sub(native, &t1, &tmp); + GE6 { + b0: t0, + b1: t1, + b2: t2, + } + } + pub fn mul_by_non_residue>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.mul_by_non_residue(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: x.b0.my_clone(), + b2: x.b1.my_clone(), + } + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE6, y: &GE6) { + self.ext2.assert_isequal(native, &x.b0, &y.b0); + self.ext2.assert_isequal(native, &x.b1, &y.b1); + self.ext2.assert_isequal(native, &x.b2, &y.b2); + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE6, + z0: &GE6, + ) -> GE6 { + let b0 = self + .ext2 + .select(native, selector, &z1.b0.my_clone(), &z0.b0.my_clone()); + let b1 = self + .ext2 + .select(native, selector, &z1.b1.my_clone(), &z0.b1.my_clone()); + let b2 = self + .ext2 + .select(native, selector, &z1.b2.my_clone(), &z0.b2.my_clone()); + GE6 { b0, b1, b2 } + } + pub fn mul_karatsuba_over_karatsuba>( + &mut self, + native: &mut B, + x: &GE6, + y: &GE6, + ) -> GE6 { + let t0 = self.ext2.mul(native, &x.b0.my_clone(), &y.b0.my_clone()); + let t1 = self.ext2.mul(native, &x.b1.my_clone(), &y.b1.my_clone()); + let t2 = self.ext2.mul(native, &x.b2.my_clone(), &y.b2.my_clone()); + let mut c0 = self.ext2.add(native, &x.b1.my_clone(), &x.b2.my_clone()); + let mut tmp = self.ext2.add(native, &y.b1.my_clone(), &y.b2.my_clone()); + c0 = self.ext2.mul(native, &c0, &tmp); + tmp = self.ext2.add(native, &t2, &t1); + c0 = self.ext2.sub(native, &c0, &tmp); + c0 = self.ext2.mul_by_non_residue(native, &c0); + c0 = self.ext2.add(native, &c0, &t0); + let mut c1 = self.ext2.add(native, &x.b0.my_clone(), &x.b1.my_clone()); + tmp = self.ext2.add(native, &y.b0.my_clone(), &y.b1.my_clone()); + c1 = self.ext2.mul(native, &c1, &tmp); + tmp = self.ext2.add(native, &t0, &t1); + c1 = self.ext2.sub(native, &c1, &tmp); + tmp = self.ext2.mul_by_non_residue(native, &t2); + c1 = self.ext2.add(native, &c1, &tmp); + let mut tmp = self.ext2.add(native, &x.b0.my_clone(), &x.b2.my_clone()); + let mut c2 = self.ext2.add(native, &y.b0.my_clone(), &y.b2.my_clone()); + c2 = self.ext2.mul(native, &c2, &tmp); + tmp = self.ext2.add(native, &t0, &t2); + c2 = self.ext2.sub(native, &c2, &tmp); + c2 = self.ext2.add(native, &c2, &t1); + GE6 { + b0: c0, + b1: c1, + b2: c2, + } + } + pub fn mul>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + self.mul_karatsuba_over_karatsuba(native, x, y) + } + pub fn div>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let inputs = vec![ + x.b0.a0.my_clone(), + x.b0.a1.my_clone(), + x.b1.a0.my_clone(), + x.b1.a1.my_clone(), + x.b2.a0.my_clone(), + x.b2.a1.my_clone(), + y.b0.a0.my_clone(), + y.b0.a1.my_clone(), + y.b1.a0.my_clone(), + y.b1.a1.my_clone(), + y.b2.a0.my_clone(), + y.b2.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.dive6hint", 6, inputs); + let div = GE6 { + b0: GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }, + b1: GE2 { + a0: output[2].my_clone(), + a1: output[3].my_clone(), + }, + b2: GE2 { + a0: output[4].my_clone(), + a1: output[5].my_clone(), + }, + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, &x.my_clone(), &_x); + div + } + pub fn inverse_div>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let one = self.one(); + self.div(native, &one, x) + } + pub fn inverse>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let inputs = vec![ + x.b0.a0.my_clone(), + x.b0.a1.my_clone(), + x.b1.a0.my_clone(), + x.b1.a1.my_clone(), + x.b2.a0.my_clone(), + x.b2.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.inversee6hint", 6, inputs); + let inv = GE6 { + b0: GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }, + b1: GE2 { + a0: output[2].my_clone(), + a1: output[3].my_clone(), + }, + b2: GE2 { + a0: output[4].my_clone(), + a1: output[5].my_clone(), + }, + }; + let one = self.one(); + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn div_e6_by_6>( + &mut self, + native: &mut B, + x: &[Element; 6], + ) -> [Element; 6] { + let inputs = vec![ + x[0].my_clone(), + x[1].my_clone(), + x[2].my_clone(), + x[3].my_clone(), + x[4].my_clone(), + x[5].my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.dive6by6hint", 6, inputs); + let y0 = output[0].my_clone(); + let y1 = output[1].my_clone(); + let y2 = output[2].my_clone(); + let y3 = output[3].my_clone(); + let y4 = output[4].my_clone(); + let y5 = output[5].my_clone(); + let x0 = self.ext2.curve_f.mul_const(native, &y0, BigInt::from(6)); + let x1 = self.ext2.curve_f.mul_const(native, &y1, BigInt::from(6)); + let x2 = self.ext2.curve_f.mul_const(native, &y2, BigInt::from(6)); + let x3 = self.ext2.curve_f.mul_const(native, &y3, BigInt::from(6)); + let x4 = self.ext2.curve_f.mul_const(native, &y4, BigInt::from(6)); + let x5 = self.ext2.curve_f.mul_const(native, &y5, BigInt::from(6)); + self.ext2.curve_f.assert_isequal(native, &x[0], &x0); + self.ext2.curve_f.assert_isequal(native, &x[1], &x1); + self.ext2.curve_f.assert_isequal(native, &x[2], &x2); + self.ext2.curve_f.assert_isequal(native, &x[3], &x3); + self.ext2.curve_f.assert_isequal(native, &x[4], &x4); + self.ext2.curve_f.assert_isequal(native, &x[5], &x5); + [y0, y1, y2, y3, y4, y5] + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs new file mode 100644 index 00000000..f2828701 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod e12; +pub mod e2; +pub mod e6; diff --git a/circuit-std-rs/src/gnark/emulated/mod.rs b/circuit-std-rs/src/gnark/emulated/mod.rs new file mode 100644 index 00000000..89f7a447 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/mod.rs @@ -0,0 +1,2 @@ +pub mod field_bls12381; +pub mod sw_bls12381; diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs new file mode 100644 index 00000000..daaadfe6 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs @@ -0,0 +1,62 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e2::CurveF; +use expander_compiler::frontend::*; + +#[derive(Default, Clone)] +pub struct G1Affine { + pub x: Element, + pub y: Element, +} +impl G1Affine { + pub fn new(x: Element, y: Element) -> Self { + Self { x, y } + } + pub fn from_vars(x: Vec, y: Vec) -> Self { + Self { + x: Element::new(x, 0, false, false, false, Variable::default()), + y: Element::new(y, 0, false, false, false, Variable::default()), + } + } + pub fn one>(native: &mut B) -> Self { + //g1Gen.X.SetString("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507") + //g1Gen.Y.SetString("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569") + Self { + x: value_of::(native, Box::new("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507".to_string())), + y: value_of::(native, Box::new("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569".to_string())), + } + } +} +pub struct G1 { + pub curve_f: CurveF, + pub w: Element, +} + +impl G1 { + pub fn new>(native: &mut B) -> Self { + let curve_f = CurveF::new(native, Bls12381Fp {}); + let w = value_of::( native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + + Self { curve_f, w } + } + pub fn add>( + &mut self, + native: &mut B, + p: &G1Affine, + q: &G1Affine, + ) -> G1Affine { + let qypy = self.curve_f.sub(native, &q.y, &p.y); + let qxpx = self.curve_f.sub(native, &q.x, &p.x); + let λ = self.curve_f.div(native, &qypy, &qxpx); + + let λλ = self.curve_f.mul(native, &λ, &λ); + let qxpx = self.curve_f.add(native, &p.x, &q.x); + let xr = self.curve_f.sub(native, &λλ, &qxpx); + + let pxrx = self.curve_f.sub(native, &p.x, &xr); + let λpxrx = self.curve_f.mul(native, &λ, &pxrx); + let yr = self.curve_f.sub(native, &λpxrx, &p.y); + + G1Affine { x: xr, y: yr } + } +} diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs new file mode 100644 index 00000000..96e2074c --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs @@ -0,0 +1,67 @@ +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e2::Ext2; +use crate::gnark::emulated::field_bls12381::e2::GE2; +use expander_compiler::frontend::*; +#[derive(Default, Clone)] +pub struct G2AffP { + pub x: GE2, + pub y: GE2, +} + +impl G2AffP { + pub fn new(x: GE2, y: GE2) -> Self { + Self { x, y } + } + pub fn from_vars( + x0: Vec, + y0: Vec, + x1: Vec, + y1: Vec, + ) -> Self { + Self { + x: GE2::from_vars(x0, y0), + y: GE2::from_vars(x1, y1), + } + } +} + +pub struct G2 { + pub curve_f: Ext2, +} + +impl G2 { + pub fn new>(native: &mut B) -> Self { + let curve_f = Ext2::new(native); + Self { curve_f } + } + pub fn neg>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { + let yr = self.curve_f.neg(native, &p.y); + G2AffP::new(p.x.my_clone(), yr) + } +} +#[derive(Default)] +pub struct LineEvaluation { + pub r0: GE2, + pub r1: GE2, +} + +type LineEvaluationArray = [[Option>; 63]; 2]; + +pub struct LineEvaluations(pub LineEvaluationArray); + +impl Default for LineEvaluations { + fn default() -> Self { + LineEvaluations([[None; 63]; 2].map(|row: [Option; 63]| row.map(|_| None))) + } +} +impl LineEvaluations { + pub fn is_empty(&self) -> bool { + self.0 + .iter() + .all(|row| row.iter().all(|cell| cell.is_none())) + } +} +pub struct G2Affine { + pub p: G2AffP, + pub lines: LineEvaluations, +} diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs new file mode 100644 index 00000000..245eaeb4 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod g1; +pub mod g2; +pub mod pairing; diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs new file mode 100644 index 00000000..41d6b330 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs @@ -0,0 +1,439 @@ +use super::g1::G1Affine; +use super::g2::G2AffP; +use super::g2::G2Affine; +use super::g2::LineEvaluation; +use super::g2::LineEvaluations; +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e12::*; +use crate::gnark::emulated::field_bls12381::e2::*; +use crate::gnark::emulated::field_bls12381::e6::GE6; +use expander_compiler::frontend::{Config, Error, RootAPI}; +use num_bigint::BigInt; + +const LOOP_COUNTER: [i8; 64] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, +]; +pub struct Pairing { + pub ext12: Ext12, + pub curve_f: CurveF, +} + +impl Pairing { + pub fn new>(native: &mut B) -> Self { + let curve_f = CurveF::new(native, Bls12381Fp {}); + let ext12 = Ext12::new(native); + Self { curve_f, ext12 } + } + pub fn pairing_check>( + &mut self, + native: &mut B, + p: &[G1Affine], + q: &mut [G2Affine], + ) -> Result<(), Error> { + let f = self.miller_loop(native, p, q).unwrap(); + let buf = self.ext12.conjugate(native, &f); + + let buf = self.ext12.div(native, &buf, &f); + let f = self.ext12.frobenius_square(native, &buf); + let f = self.ext12.mul(native, &f, &buf); + + self.ext12.assert_final_exponentiation_is_one(native, &f); + + Ok(()) + } + pub fn miller_loop>( + &mut self, + native: &mut B, + p: &[G1Affine], + q: &mut [G2Affine], + ) -> Result { + let n = p.len(); + if n == 0 || n != q.len() { + return Err("nvalid inputs sizes".to_string()); + } + let mut lines = vec![]; + for cur_q in q { + if cur_q.lines.is_empty() { + let qlines = self.compute_lines_with_hint(native, &cur_q.p); + cur_q.lines = qlines; + } + let line_evaluations = std::mem::take(&mut cur_q.lines); + lines.push(line_evaluations); + } + self.miller_loop_lines_with_hint(native, p, lines) + } + pub fn miller_loop_lines_with_hint>( + &mut self, + native: &mut B, + p: &[G1Affine], + lines: Vec, + ) -> Result { + let n = p.len(); + if n == 0 || n != lines.len() { + return Err("invalid inputs sizes".to_string()); + } + let mut y_inv = vec![]; + let mut x_neg_over_y = vec![]; + for cur_p in p.iter().take(n) { + let y_inv_k = self.curve_f.inverse(native, &cur_p.y); + let x_neg_over_y_k = self.curve_f.mul(native, &cur_p.x, &y_inv_k); + let x_neg_over_y_k = self.curve_f.neg(native, &x_neg_over_y_k); + y_inv.push(y_inv_k); + x_neg_over_y.push(x_neg_over_y_k); + } + + let mut res = self.ext12.one(); + + if let Some(line_evaluation) = &lines[0].0[0][62] { + let line = line_evaluation; + res.c0.b0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[0]); + res.c0.b1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[0]); + } else { + return Err("line evaluation is None".to_string()); + } + res.c1.b1 = self.ext12.ext6.ext2.one(); + + if let Some(line_evaluation) = &lines[0].0[1][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[0]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[0]); + let prod_lines = self + .ext12 + .mul_014_by_014(native, &tmp0, &tmp1, &res.c0.b0, &res.c0.b1); + res = GE12 { + c0: GE6 { + b0: prod_lines[0].my_clone(), + b1: prod_lines[1].my_clone(), + b2: prod_lines[2].my_clone(), + }, + c1: GE6 { + b0: res.c1.b0.my_clone(), + b1: prod_lines[3].my_clone(), + b2: prod_lines[4].my_clone(), + }, + }; + } else { + return Err("line evaluation is None".to_string()); + } + + for k in 1..n { + if let Some(line_evaluation) = &lines[k].0[0][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, &res, &tmp0, &tmp1); + } else { + return Err("line evaluation is None".to_string()); + } + if let Some(line_evaluation) = &lines[k].0[1][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, &res, &tmp0, &tmp1); + } else { + return Err("line evaluation is None".to_string()); + } + } + + let mut copy_res = self.ext12.copy(native, &res); + + for i in (0..=61).rev() { + res = self.ext12.square(native, ©_res); + copy_res = self.ext12.copy(native, &res); + for k in 0..n { + if LOOP_COUNTER[i as usize] == 0 { + if let Some(line_evaluation) = &lines[k].0[0][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + } else { + if let Some(line_evaluation) = &lines[k].0[0][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + if let Some(line_evaluation) = &lines[k].0[1][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + } + } + } + res = self.ext12.conjugate(native, ©_res); + Ok(res) + } + pub fn compute_lines_with_hint>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> LineEvaluations { + // let mut c_lines = LineEvaluations::default(); + let mut c_lines: LineEvaluations = LineEvaluations::default(); + let q_acc = q; + let mut copy_q_acc = self.copy_g2_aff_p(native, q_acc); + let n = LOOP_COUNTER.len(); + let (q_acc, line1, line2) = self.triple_step(native, copy_q_acc); + c_lines.0[0][n - 2] = line1; + c_lines.0[1][n - 2] = line2; + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + for i in (1..=n - 3).rev() { + if LOOP_COUNTER[i] == 0 { + let (q_acc, c_lines_0_i) = self.double_step(native, copy_q_acc); + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + c_lines.0[0][i] = c_lines_0_i; + } else { + let (q_acc, c_lines_0_i, c_lines_1_i) = + self.double_and_add_step(native, copy_q_acc, q); + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + c_lines.0[0][i] = c_lines_0_i; + c_lines.0[1][i] = c_lines_1_i; + } + } + c_lines.0[0][0] = self.tangent_compute(native, copy_q_acc); + c_lines + } + pub fn double_and_add_step>( + &mut self, + native: &mut B, + p1: G2AffP, + p2: &G2AffP, + ) -> ( + G2AffP, + Option>, + Option>, + ) { + let n = self.ext12.ext6.ext2.sub(native, &p1.y, &p2.y); + let d = self.ext12.ext6.ext2.sub(native, &p1.x, &p2.x); + let λ1 = self.ext12.ext6.ext2.div(native, &n, &d); + + let xr = self.ext12.ext6.ext2.square(native, &λ1); + let tmp = self.ext12.ext6.ext2.add(native, &p1.x, &p2.x); + let xr = self.ext12.ext6.ext2.sub(native, &xr, &tmp); + + let r0 = λ1.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ1, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line1 = Some(Box::new(LineEvaluation { r0, r1 })); + + let d = self.ext12.ext6.ext2.sub(native, &xr, &p1.x); + let n = self.ext12.ext6.ext2.double(native, &p1.y); + let λ2 = self.ext12.ext6.ext2.div(native, &n, &d); + let λ2 = self.ext12.ext6.ext2.add(native, &λ2, &λ1); + let λ2 = self.ext12.ext6.ext2.neg(native, &λ2); + + let x4 = self.ext12.ext6.ext2.square(native, &λ2); + let tmp = self.ext12.ext6.ext2.add(native, &p1.x, &xr); + let x4 = self.ext12.ext6.ext2.sub(native, &x4, &tmp); + + let y4 = self.ext12.ext6.ext2.sub(native, &p1.x, &x4); + let y4 = self.ext12.ext6.ext2.mul(native, &λ2, &y4); + let y4 = self.ext12.ext6.ext2.sub(native, &y4, &p1.y); + + let p = G2AffP { x: x4, y: y4 }; + + let r0 = λ2.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ2, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line2 = Some(Box::new(LineEvaluation { r0, r1 })); + + (p, line1, line2) + } + pub fn double_step>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> (G2AffP, Option>) { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ = self.ext12.ext6.ext2.div(native, &n, &d); + + let xr = self.ext12.ext6.ext2.square(native, &λ); + let tmp = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &p1.x, &BigInt::from(2)); + let xr = self.ext12.ext6.ext2.sub(native, &xr, &tmp); + + let pxr = self.ext12.ext6.ext2.sub(native, &p1.x, &xr); + let λpxr = self.ext12.ext6.ext2.mul(native, &λ, &pxr); + let yr = self.ext12.ext6.ext2.sub(native, &λpxr, &p1.y); + + let res = G2AffP { x: xr, y: yr }; + + let r0 = λ.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line = Some(Box::new(LineEvaluation { r0, r1 })); + + (res, line) + } + pub fn triple_step>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> ( + G2AffP, + Option>, + Option>, + ) { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ1 = self.ext12.ext6.ext2.div(native, &n, &d); + + let r0 = λ1.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ1, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line1 = Some(Box::new(LineEvaluation { r0, r1 })); + + let x2 = self.ext12.ext6.ext2.square(native, &λ1); + let tmp = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &p1.x, &BigInt::from(2)); + let x2 = self.ext12.ext6.ext2.sub(native, &x2, &tmp); + + let x1x2 = self.ext12.ext6.ext2.sub(native, &p1.x, &x2); + let λ2 = self.ext12.ext6.ext2.div(native, &d, &x1x2); + let λ2 = self.ext12.ext6.ext2.sub(native, &λ2, &λ1); + + let r0 = λ2.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ2, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line2 = Some(Box::new(LineEvaluation { r0, r1 })); + + let λ2λ2 = self.ext12.ext6.ext2.mul(native, &λ2, &λ2); + let qxrx = self.ext12.ext6.ext2.add(native, &x2, &p1.x); + let xr = self.ext12.ext6.ext2.sub(native, &λ2λ2, &qxrx); + + let pxrx = self.ext12.ext6.ext2.sub(native, &p1.x, &xr); + let λ2pxrx = self.ext12.ext6.ext2.mul(native, &λ2, &pxrx); + let yr = self.ext12.ext6.ext2.sub(native, &λ2pxrx, &p1.y); + + let res = G2AffP { x: xr, y: yr }; + + (res, line1, line2) + } + pub fn tangent_compute>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> Option> { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ = self.ext12.ext6.ext2.div(native, &n, &d); + + let r0 = λ.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + Some(Box::new(LineEvaluation { r0, r1 })) + } + pub fn copy_g2_aff_p>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> G2AffP { + let copy_q_acc_x = self.ext12.ext6.ext2.copy(native, &q.x); + let copy_q_acc_y = self.ext12.ext6.ext2.copy(native, &q.y); + G2AffP { + x: copy_q_acc_x, + y: copy_q_acc_y, + } + } +} diff --git a/circuit-std-rs/src/gnark/field.rs b/circuit-std-rs/src/gnark/field.rs new file mode 100644 index 00000000..7c50a9e2 --- /dev/null +++ b/circuit-std-rs/src/gnark/field.rs @@ -0,0 +1,663 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::FieldParams; +use crate::gnark::utils::*; +use crate::logup::LogUpRangeProofTable; +use crate::utils::simple_select; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::Signed; +use num_traits::ToPrimitive; +use num_traits::Zero; +use std::collections::HashMap; + +pub struct MulCheck { + a: Element, + b: Element, + r: Element, + k: Element, + c: Element, + p: Element, +} +impl MulCheck { + pub fn eval_round1>(&mut self, native: &mut B, at: Vec) { + self.c = eval_with_challenge(native, self.c.my_clone(), at.clone()); + self.r = eval_with_challenge(native, self.r.my_clone(), at.clone()); + self.k = eval_with_challenge(native, self.k.my_clone(), at.clone()); + if !self.p.is_empty() { + self.p = eval_with_challenge(native, self.p.my_clone(), at.clone()); + } + } + pub fn eval_round2>(&mut self, native: &mut B, at: Vec) { + self.a = eval_with_challenge(native, self.a.my_clone(), at.clone()); + self.b = eval_with_challenge(native, self.b.my_clone(), at.clone()); + } + pub fn check>(&self, native: &mut B, pval: Variable, ccoef: Variable) { + let mut new_peval = pval; + if !self.p.is_empty() { + new_peval = self.p.evaluation + }; + let ls = native.mul(self.a.evaluation, self.b.evaluation); + let rs_tmp1 = native.mul(new_peval, self.k.evaluation); + let rs_tmp2 = native.mul(self.c.evaluation, ccoef); + let rs_tmp3 = native.add(self.r.evaluation, rs_tmp1); + let rs = native.add(rs_tmp3, rs_tmp2); + native.assert_is_equal(ls, rs); + } + pub fn clean_evaluations(&mut self) { + self.a.evaluation = Variable::default(); + self.a.is_evaluated = false; + self.b.evaluation = Variable::default(); + self.b.is_evaluated = false; + self.r.evaluation = Variable::default(); + self.r.is_evaluated = false; + self.k.evaluation = Variable::default(); + self.k.is_evaluated = false; + self.c.evaluation = Variable::default(); + self.c.is_evaluated = false; + self.p.evaluation = Variable::default(); + self.p.is_evaluated = false; + } +} +pub struct GField { + _f_params: T, + max_of: u32, + n_const: Element, + nprev_const: Element, + pub zero_const: Element, + pub one_const: Element, + short_one_const: Element, + constrained_limbs: HashMap, + pub table: LogUpRangeProofTable, + //checker: Box, we use lookup rangeproof instead + mul_checks: Vec>, +} + +impl GField { + pub fn new>(native: &mut B, f_params: T) -> Self { + let mut field = GField { + _f_params: f_params, + max_of: 30 - 2 - T::bits_per_limb(), + n_const: Element::::my_default(), + nprev_const: Element::::my_default(), + zero_const: Element::::my_default(), + one_const: Element::::my_default(), + short_one_const: Element::::my_default(), + constrained_limbs: HashMap::new(), + table: LogUpRangeProofTable::new(8), + mul_checks: Vec::new(), + }; + field.n_const = value_of::(native, Box::new(T::modulus())); + field.nprev_const = value_of::(native, Box::new(T::modulus() - 1)); + field.zero_const = value_of::(native, Box::new(0)); + field.one_const = value_of::(native, Box::new(1)); + field.short_one_const = new_internal_element::(vec![native.constant(1); 1], 0); + field.table.initial(native); + field + } + pub fn max_overflow(&self) -> u64 { + 30 - 2 - 8 + } + pub fn is_zero>( + &mut self, + native: &mut B, + a: &Element, + ) -> Variable { + let ca = self.reduce(native, a, false); + let mut res0; + let total_overflow = ca.limbs.len() as i32 - 1; + if total_overflow > self.max_overflow() as i32 { + res0 = native.is_zero(ca.limbs[0]); + for i in 1..ca.limbs.len() { + let tmp = native.is_zero(ca.limbs[i]); + res0 = native.mul(res0, tmp); + } + } else { + let mut limb_sum = ca.limbs[0]; + for i in 1..ca.limbs.len() { + limb_sum = native.add(limb_sum, ca.limbs[i]); + } + res0 = native.is_zero(limb_sum); + } + res0 + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let overflow = std::cmp::max(a.overflow, b.overflow); + let nb_limbs = std::cmp::max(a.limbs.len(), b.limbs.len()); + let mut limbs = vec![native.constant(0); nb_limbs]; + let mut normalize = |limbs: Vec| -> Vec { + if limbs.len() < nb_limbs { + let mut tail = vec![native.constant(0); nb_limbs - limbs.len()]; + for cur_tail in &mut tail { + *cur_tail = native.constant(0); + } + return limbs.iter().chain(tail.iter()).cloned().collect(); + } + limbs + }; + let a_norm_limbs = normalize(a.limbs.clone()); + let b_norm_limbs = normalize(b.limbs.clone()); + for i in 0..limbs.len() { + limbs[i] = simple_select(native, selector, a_norm_limbs[i], b_norm_limbs[i]); + } + new_internal_element::(limbs, overflow) + } + pub fn enforce_width_conditional>( + &mut self, + native: &mut B, + a: &Element, + ) -> bool { + let mut did_constrain = false; + if a.internal { + return false; + } + for i in 0..a.limbs.len() { + let value_id = a.limbs[i].id(); + if let std::collections::hash_map::Entry::Vacant(e) = + self.constrained_limbs.entry(value_id) + { + e.insert(()); + } else { + did_constrain = true; + } + } + self.enforce_width(native, a, true); + did_constrain + } + pub fn enforce_width>( + &mut self, + native: &mut B, + a: &Element, + mod_width: bool, + ) { + for i in 0..a.limbs.len() { + let mut limb_nb_bits = T::bits_per_limb() as u64; + if mod_width && i == a.limbs.len() - 1 { + limb_nb_bits = ((T::modulus().bits() - 1) % T::bits_per_limb() as u64) + 1; + } + //range check + if limb_nb_bits > 8 { + self.table + .rangeproof(native, a.limbs[i], limb_nb_bits as usize); + } else { + self.table + .rangeproof_onechunk(native, a.limbs[i], limb_nb_bits as usize); + } + } + } + pub fn wrap_hint>( + &self, + native: &mut B, + nonnative_inputs: Vec>, + ) -> Vec { + let mut res = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + ]; + res.extend(self.n_const.limbs.clone()); + res.push(native.constant(nonnative_inputs.len() as u32)); + for nonnative_input in &nonnative_inputs { + res.push(native.constant(nonnative_input.limbs.len() as u32)); + res.extend(nonnative_input.limbs.clone()); + } + res + } + pub fn new_hint>( + &mut self, + native: &mut B, + hf_name: &str, + nb_outputs: usize, + inputs: Vec>, + ) -> Vec> { + let native_inputs = self.wrap_hint(native, inputs); + let nb_native_outputs = T::nb_limbs() as usize * nb_outputs; + let native_outputs = native.new_hint(hf_name, &native_inputs, nb_native_outputs); + let mut outputs = vec![]; + for i in 0..nb_outputs { + let tmp_output = self.pack_limbs( + native, + native_outputs[i * T::nb_limbs() as usize..(i + 1) * T::nb_limbs() as usize] + .to_vec(), + true, + ); + outputs.push(tmp_output); + } + outputs + } + pub fn pack_limbs>( + &mut self, + native: &mut B, + limbs: Vec, + strict: bool, + ) -> Element { + let e = new_internal_element::(limbs, 0); + self.enforce_width(native, &e, strict); + e + } + pub fn reduce>( + &mut self, + native: &mut B, + a: &Element, + strict: bool, + ) -> Element { + self.enforce_width_conditional(native, a); + if a.mod_reduced { + return a.my_clone(); + } + if !strict && a.overflow == 0 { + return a.my_clone(); + } + let p = Element::::my_default(); + let one = self.one_const.my_clone(); + self.mul_mod(native, a, &one, 0, &p).my_clone() + } + pub fn mul_mod>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + _: usize, + p: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + let (k, r, c) = self.call_mul_hint(native, a, b, true); + let mc = MulCheck { + a: a.my_clone(), + b: b.my_clone(), + c, + k, + r: r.my_clone(), + p: p.my_clone(), + }; + self.mul_checks.push(mc); + r + } + pub fn mul_pre_cond(&self, a: &Element, b: &Element) -> u32 { + let nb_res_limbs = nb_multiplication_res_limbs(a.limbs.len(), b.limbs.len()); + let nb_limbs_overflow = if nb_res_limbs > 0 { + (nb_res_limbs as f64).log2().ceil() as u32 + } else { + 1 + }; + T::bits_per_limb() + nb_limbs_overflow + a.overflow + b.overflow + } + pub fn call_mul_hint>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + is_mul_mod: bool, + ) -> (Element, Element, Element) { + let next_overflow = self.mul_pre_cond(a, b); + let next_overflow = if !is_mul_mod { + a.overflow + } else { + next_overflow + }; + let nb_limbs = T::nb_limbs() as usize; + let nb_bits = T::bits_per_limb() as usize; + let modbits = T::modulus().bits() as usize; + let a_limbs_len = a.limbs.len(); + let b_limbs_len = b.limbs.len(); + let nb_quo_limbs = (nb_multiplication_res_limbs(a_limbs_len, b_limbs_len) * nb_bits + + next_overflow as usize + + 1 + - modbits + + nb_bits + - 1) + / nb_bits; + let nb_rem_limbs = nb_limbs; + let nb_carry_limbs = std::cmp::max( + nb_multiplication_res_limbs(a_limbs_len, b_limbs_len), + nb_multiplication_res_limbs(nb_quo_limbs, nb_limbs), + ) - 1; + let mut hint_inputs = vec![ + native.constant(nb_bits as u32), + native.constant(nb_limbs as u32), + native.constant(a.limbs.len() as u32), + native.constant(nb_quo_limbs as u32), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(a.limbs.clone()); + hint_inputs.extend(b.limbs.clone()); + let ret = native.new_hint( + "myhint.mulhint", + &hint_inputs, + nb_quo_limbs + nb_rem_limbs + nb_carry_limbs, + ); + let quo = self.pack_limbs(native, ret[..nb_quo_limbs].to_vec(), false); + let rem = if is_mul_mod { + self.pack_limbs( + native, + ret[nb_quo_limbs..nb_quo_limbs + nb_rem_limbs].to_vec(), + true, + ) + } else { + Element::my_default() + }; + let carries = new_internal_element::(ret[nb_quo_limbs + nb_rem_limbs..].to_vec(), 0); + (quo, rem, carries) + } + pub fn check_zero>( + &mut self, + native: &mut B, + a: Element, + p: Option>, + ) { + self.enforce_width_conditional(native, &a.my_clone()); + let b = self.short_one_const.my_clone(); + let (k, r, c) = self.call_mul_hint(native, &a, &b, false); + let mc = MulCheck { + a, + b, + c, + k, + r: r.my_clone(), + p: p.unwrap_or(Element::::my_default()), + }; + self.mul_checks.push(mc); + } + pub fn assert_isequal>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + let diff = self.sub(native, b, a); + self.check_zero(native, diff, None); + } + pub fn add>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if a.overflow + 1 > self.max_of { + new_a = self.reduce(native, a, false); + } + if b.overflow + 1 > self.max_of { + new_b = self.reduce(native, b, false); + } + let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow) + 1; + let nb_limbs = std::cmp::max(new_a.limbs.len(), new_b.limbs.len()); + let mut limbs = vec![native.constant(0); nb_limbs]; + for (i, limb) in limbs.iter_mut().enumerate() { + if i < new_a.limbs.len() { + *limb = native.add(*limb, new_a.limbs[i]); + } + if i < new_b.limbs.len() { + *limb = native.add(*limb, new_b.limbs[i]); + } + } + new_internal_element::(limbs, next_overflow) + } + pub fn sub>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if a.overflow + 1 > self.max_of { + new_a = self.reduce(native, a, false); + } + if b.overflow + 2 > self.max_of { + new_b = self.reduce(native, b, false); + } + let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow + 1) + 1; + let nb_limbs = std::cmp::max(new_a.limbs.len(), new_b.limbs.len()); + let pad_limbs = sub_padding( + &T::modulus(), + T::bits_per_limb(), + new_b.overflow, + nb_limbs as u32, + ); + let mut limbs = vec![native.constant(0); nb_limbs]; + for i in 0..limbs.len() { + limbs[i] = native.constant(pad_limbs[i].to_u64().unwrap() as u32); + if i < new_a.limbs.len() { + limbs[i] = native.add(limbs[i], new_a.limbs[i]); + } + if i < new_b.limbs.len() { + limbs[i] = native.sub(limbs[i], new_b.limbs[i]); + } + } + new_internal_element::(limbs, next_overflow) + } + pub fn neg>(&mut self, native: &mut B, a: &Element) -> Element { + let zero = self.zero_const.my_clone(); + self.sub(native, &zero, a) + } + pub fn mul>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + + //calculate a*b's overflow and reduce if necessary + let mut next_overflow = self.mul_pre_cond(a, b); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if next_overflow > self.max_of { + if a.overflow < b.overflow { + new_b = self.reduce(native, b, false); + } else { + new_a = self.reduce(native, a, false); + } + } + next_overflow = self.mul_pre_cond(&new_a, &new_b); + if next_overflow > self.max_of { + if new_a.overflow < new_b.overflow { + new_b = self.reduce(native, &new_b, false); + } else { + new_a = self.reduce(native, &new_a, false); + } + } + + //calculate a*b + self.mul_mod(native, &new_a, &new_b, 0, &Element::::my_default()) + } + pub fn div>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + //calculate a/b's overflow and reduce if necessary + let zero_element = self.zero_const.my_clone(); + let mut mul_of = self.mul_pre_cond(&zero_element, b); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if mul_of > self.max_of { + new_b = self.reduce(native, &new_b, false); + mul_of = 0; + } + if new_a.overflow + 1 > self.max_of { + new_a = self.reduce(native, &new_a, false); + } + if mul_of + 2 > self.max_of { + new_b = self.reduce(native, &new_b, false); + } + + //calculate a/b + let div = self.compute_division_hint(native, a.limbs.clone(), b.limbs.clone()); + let e = self.pack_limbs(native, div, true); + let res = self.mul(native, &e, &new_b); + self.assert_isequal(native, &res, &new_a); + e + } + /* + mulOf, err := f.mulPreCond(a, &Element[T]{Limbs: make([]frontend.Variable, f.fParams.NbLimbs()), overflow: 0}) // order is important, we want that reduce left side + if err != nil { + return mulOf, err + } + return f.subPreCond(&Element[T]{overflow: 0}, &Element[T]{overflow: mulOf}) + */ + pub fn inverse>( + &mut self, + native: &mut B, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, b); + //calculate 1/b's overflow and reduce if necessary + let zero_element = self.zero_const.my_clone(); + let mut mul_of = self.mul_pre_cond(&zero_element, b); + let mut new_b = b.my_clone(); + if mul_of > self.max_of { + new_b = self.reduce(native, &new_b, false); + mul_of = 0; + } + if mul_of + 2 > self.max_of { + new_b = self.reduce(native, &new_b, false); + } + // let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow+1) + 1; + + //calculate 1/b + let inv = self.compute_inverse_hint(native, b.limbs.clone()); + let e = self.pack_limbs(native, inv, true); + let res = self.mul(native, &e, &new_b); + let one = self.one_const.my_clone(); + self.assert_isequal(native, &res, &one); + e + } + pub fn compute_inverse_hint>( + &mut self, + native: &mut B, + in_limbs: Vec, + ) -> Vec { + let mut hint_inputs = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(in_limbs); + native.new_hint("myhint.invhint", &hint_inputs, T::nb_limbs() as usize) + } + pub fn compute_division_hint>( + &mut self, + native: &mut B, + nom_limbs: Vec, + denom_limbs: Vec, + ) -> Vec { + let mut hint_inputs = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + native.constant(denom_limbs.len() as u32), + native.constant(nom_limbs.len() as u32), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(nom_limbs); + hint_inputs.extend(denom_limbs); + native.new_hint("myhint.divhint", &hint_inputs, T::nb_limbs() as usize) + } + pub fn mul_const>( + &mut self, + native: &mut B, + a: &Element, + c: BigInt, + ) -> Element { + if c.is_negative() { + let neg_a = self.neg(native, a); + return self.mul_const(native, &neg_a, -c); + } else if c.is_zero() { + return self.zero_const.my_clone(); + } + let cbl = c.bits(); + if cbl > self.max_overflow() { + panic!( + "constant bit length {} exceeds max {}", + cbl, + self.max_overflow() + ); + } + let next_overflow = a.overflow + cbl as u32; + let mut new_a = a.my_clone(); + if next_overflow > self.max_of { + new_a = self.reduce(native, a, false); + } + let mut limbs = vec![native.constant(0); new_a.limbs.len()]; + for i in 0..new_a.limbs.len() { + limbs[i] = native.mul(new_a.limbs[i], c.to_u64().unwrap() as u32); + } + new_internal_element::(limbs, new_a.overflow + cbl as u32) + } + pub fn check_mul>(&mut self, native: &mut B) { + let commitment = native.get_random_value(); + // let commitment = native.constant(1); //TBD + let mut coefs_len = T::nb_limbs() as usize; + for i in 0..self.mul_checks.len() { + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].a.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].b.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].c.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].k.limbs.len()); + } + let mut at = vec![commitment; coefs_len]; + for i in 1..at.len() { + at[i] = native.mul(at[i - 1], commitment); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].eval_round1(native, at.clone()); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].eval_round2(native, at.clone()); + } + let pval = eval_with_challenge(native, self.n_const.my_clone(), at.clone()); + let coef = BigInt::from(1) << T::bits_per_limb(); + let ccoef = native.sub(coef.to_u64().unwrap() as u32, commitment); + for i in 0..self.mul_checks.len() { + self.mul_checks[i].check(native, pval.evaluation, ccoef); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].clean_evaluations(); + } + } +} +pub fn eval_with_challenge, T: FieldParams>( + native: &mut B, + a: Element, + at: Vec, +) -> Element { + if a.is_evaluated { + return a; + } + if (at.len() as i64) < (a.limbs.len() as i64) - 1 { + panic!("evaluation powers less than limbs"); + } + let mut sum = native.constant(0); + if !a.limbs.is_empty() { + sum = native.mul(a.limbs[0], 1); + } + for i in 1..a.limbs.len() { + let tmp = native.mul(a.limbs[i], at[i - 1]); + sum = native.add(sum, tmp); + } + let mut ret = a.my_clone(); + ret.is_evaluated = true; + ret.evaluation = sum; + ret +} diff --git a/circuit-std-rs/src/gnark/hints.rs b/circuit-std-rs/src/gnark/hints.rs new file mode 100644 index 00000000..1b4b25e3 --- /dev/null +++ b/circuit-std-rs/src/gnark/hints.rs @@ -0,0 +1,1188 @@ +use crate::gnark::limbs::*; +use crate::gnark::utils::*; +use crate::logup::{query_count_by_key_hint, query_count_hint, rangeproof_hint}; +use crate::sha256::m31_utils::to_binary_hint; +use ark_bls12_381::Fq; +use ark_bls12_381::Fq12; +use ark_bls12_381::Fq2; +use ark_bls12_381::Fq6; +use ark_ff::fields::Field; +use ark_ff::Zero; +use expander_compiler::frontend::extra::*; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_bigint::BigUint; +use num_traits::One; +use num_traits::Signed; +use num_traits::ToPrimitive; +use std::str::FromStr; + +pub fn register_hint(hint_registry: &mut HintRegistry) { + hint_registry.register("myhint.tobinary", to_binary_hint); + hint_registry.register("myhint.mulhint", mul_hint); + hint_registry.register("myhint.simple_rangecheck_hint", simple_rangecheck_hint); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.querycountbykeyhint", query_count_by_key_hint); + hint_registry.register("myhint.copyvarshint", copy_vars_hint); + hint_registry.register("myhint.divhint", div_hint); + hint_registry.register("myhint.invhint", inv_hint); + hint_registry.register("myhint.dive2hint", div_e2_hint); + hint_registry.register("myhint.inversee2hint", inverse_e2_hint); + hint_registry.register("myhint.copye2hint", copy_e2_hint); + hint_registry.register("myhint.dive6hint", div_e6_hint); + hint_registry.register("myhint.inversee6hint", inverse_e6_hint); + hint_registry.register("myhint.dive6by6hint", div_e6_by_6_hint); + hint_registry.register("myhint.dive12hint", div_e12_hint); + hint_registry.register("myhint.inversee12hint", inverse_e12_hint); + hint_registry.register("myhint.copye12hint", copy_e12_hint); + hint_registry.register("myhint.finalexphint", final_exp_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); +} +pub fn mul_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let nb_a_len = inputs[2].to_u256().as_usize(); + let nb_quo_len = inputs[3].to_u256().as_usize(); + let nb_b_len = inputs.len() - 4 - nb_limbs - nb_a_len; + let mut ptr = 4; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let alimbs_m31 = &inputs[ptr..ptr + nb_a_len]; + let alimbs_u32: Vec = (0..nb_a_len) + .map(|i| alimbs_m31[i].to_u256().as_u32()) + .collect(); + let alimbs: Vec = alimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_a_len; + let blimbs_m31 = &inputs[ptr..ptr + nb_b_len]; + let blimbs_u32: Vec = (0..nb_b_len) + .map(|i| blimbs_m31[i].to_u256().as_u32()) + .collect(); + let blimbs: Vec = blimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let nb_carry_len = std::cmp::max( + nb_multiplication_res_limbs(nb_a_len, nb_b_len), + nb_multiplication_res_limbs(nb_quo_len, nb_limbs), + ) - 1; + + let p = recompose(plimbs.clone(), nb_bits as u32); + let a = recompose(alimbs.clone(), nb_bits as u32); + let b = recompose(blimbs.clone(), nb_bits as u32); + + let ab = a.clone() * b.clone(); + let quo = ab.clone() / p.clone(); + let rem = ab.clone() % p.clone(); + let mut quo_limbs = vec![BigInt::default(); nb_quo_len]; + if let Err(err) = decompose(&quo, nb_bits as u32, &mut quo_limbs) { + panic!("decompose value: {}", err); + } + let mut rem_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&rem, nb_bits as u32, &mut rem_limbs) { + panic!("decompose value: {}", err); + } + let mut xp = vec![BigInt::default(); nb_multiplication_res_limbs(nb_a_len, nb_b_len)]; + let mut yp = vec![BigInt::default(); nb_multiplication_res_limbs(nb_quo_len, nb_limbs)]; + let mut tmp; + for cur_xp in &mut xp { + *cur_xp = BigInt::default(); + } + for cur_yp in &mut yp { + *cur_yp = BigInt::default(); + } + // we know compute the schoolbook multiprecision multiplication of a*b and + // r+k*p + for i in 0..nb_a_len { + for j in 0..nb_b_len { + tmp = alimbs[i].clone(); + tmp *= &blimbs[j]; + xp[i + j] += &tmp; + } + } + for i in 0..nb_limbs { + yp[i] += &rem_limbs[i]; + for j in 0..nb_quo_len { + tmp = quo_limbs[j].clone(); + tmp *= &plimbs[i]; + yp[i + j] += &tmp; + } + } + let mut carry = BigInt::default(); + let mut carry_limbs = vec![BigInt::default(); nb_carry_len]; + for i in 0..carry_limbs.len() { + if i < xp.len() { + carry += &xp[i]; + } + if i < yp.len() { + carry -= &yp[i]; + } + carry >>= nb_bits as u32; + //if carry is negative, we need to add 2^nb_bits to it + carry_limbs[i] = carry.clone(); + } + //convert limbs to m31 output + let mut outptr = 0; + for i in 0..nb_quo_len { + outputs[outptr + i] = M31::from(quo_limbs[i].to_u64().unwrap() as u32); + } + outptr += nb_quo_len; + for i in 0..nb_limbs { + outputs[outptr + i] = M31::from(rem_limbs[i].to_u64().unwrap() as u32); + } + outptr += nb_limbs; + for i in 0..nb_carry_len { + if carry_limbs[i] < BigInt::default() { + outputs[outptr + i] = -M31::from(carry_limbs[i].abs().to_u64().unwrap() as u32); + } else { + outputs[outptr + i] = M31::from(carry_limbs[i].to_u64().unwrap() as u32); + } + } + Ok(()) +} +pub fn div_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let nb_denom_limbs = inputs[2].to_u256().as_usize(); + let nb_nom_limbs = inputs[3].to_u256().as_usize(); + let mut ptr = 4; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let nomlimbs_m31 = &inputs[ptr..ptr + nb_nom_limbs]; + let nomlimbs_u32: Vec = (0..nb_nom_limbs) + .map(|i| nomlimbs_m31[i].to_u256().as_u32()) + .collect(); + let nomlimbs: Vec = nomlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_nom_limbs; + let denomlimbs_m31 = &inputs[ptr..ptr + nb_denom_limbs]; + let denomlimbs_u32: Vec = (0..nb_denom_limbs) + .map(|i| denomlimbs_m31[i].to_u256().as_u32()) + .collect(); + let denomlimbs: Vec = denomlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let p = recompose(plimbs.clone(), nb_bits as u32); + let nom = recompose(nomlimbs.clone(), nb_bits as u32); + let denom = recompose(denomlimbs.clone(), nb_bits as u32); + let mut res = denom.clone().modinv(&p).unwrap(); + res *= &nom; + res %= &p; + let mut res_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&res, nb_bits as u32, &mut res_limbs) { + panic!("decompose value: {}", err); + } + for i in 0..nb_limbs { + outputs[i] = M31::from(res_limbs[i].to_u64().unwrap() as u32); + } + Ok(()) +} + +pub fn inv_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let mut ptr = 2; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let xlimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let xlimbs_u32: Vec = (0..nb_limbs) + .map(|i| xlimbs_m31[i].to_u256().as_u32()) + .collect(); + let xlimbs: Vec = xlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let p = recompose(plimbs.clone(), nb_bits as u32); + let x = recompose(xlimbs.clone(), nb_bits as u32); + let res = x.clone().modinv(&p).unwrap(); + let mut res_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&res, nb_bits as u32, &mut res_limbs) { + panic!("decompose value: {}", err); + } + for i in 0..nb_limbs { + outputs[i] = M31::from(res_limbs[i].to_u64().unwrap() as u32); + } + Ok(()) +} +pub fn div_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE2Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let b = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let c = a / b; + let c0_bigint = + c.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let c1_bigint = + c.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![c0_bigint, c1_bigint] + }, + ) { + panic!("divE2Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE2Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let c = a.inverse().unwrap(); + let c0_bigint = + c.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let c1_bigint = + c.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![c0_bigint, c1_bigint] + }, + ) { + panic!("inverseE2Hint: {}", err); + } + Ok(()) +} + +pub fn div_e6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a = Fq6::new(a_b0, a_b1, a_b2); + let b_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let b_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let b_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let b = Fq6::new(b_b0, b_b1, b_b2); + let c = a / b; + let c_c0_c0_bigint = + c.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + c.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + c.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + c.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + c.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + c.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("divE6Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a = Fq6::new(a_b0, a_b1, a_b2); + let c = a.inverse().unwrap(); + let c_c0_c0_bigint = + c.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + c.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + c.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + c.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + c.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + c.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("inverseE6Hint: {}", err); + } + Ok(()) +} + +pub fn div_e6_by_6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE6By6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let mut a = Fq6::new(a_b0, a_b1, a_b2); + let six_inv = Fq::from(6u32).inverse().unwrap(); + a.c0.mul_assign_by_fp(&six_inv); + a.c1.mul_assign_by_fp(&six_inv); + a.c2.mul_assign_by_fp(&six_inv); + let c_c0_c0_bigint = + a.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + a.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + a.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + a.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + a.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + a.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("divE6By6Hint: {}", err); + } + Ok(()) +} + +pub fn div_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let a_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a_c0 = Fq6::new(a_c0_b0, a_c0_b1, a_c0_b2); + let a_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let a_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let a_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let a_c1 = Fq6::new(a_c1_b0, a_c1_b1, a_c1_b2); + let a = Fq12::new(a_c0, a_c1); + + let b_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[12].clone()), + Fq::from(biguint_inputs[13].clone()), + ); + let b_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[14].clone()), + Fq::from(biguint_inputs[15].clone()), + ); + let b_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[16].clone()), + Fq::from(biguint_inputs[17].clone()), + ); + let b_c0 = Fq6::new(b_c0_b0, b_c0_b1, b_c0_b2); + let b_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[18].clone()), + Fq::from(biguint_inputs[19].clone()), + ); + let b_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[20].clone()), + Fq::from(biguint_inputs[21].clone()), + ); + let b_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[22].clone()), + Fq::from(biguint_inputs[23].clone()), + ); + let b_c1 = Fq6::new(b_c1_b0, b_c1_b1, b_c1_b2); + let b = Fq12::new(b_c0, b_c1); + + let c = a / b; + let c_c0_b0_a0_bigint = + c.c0.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b0_a1_bigint = + c.c0.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a0_bigint = + c.c0.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a1_bigint = + c.c0.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a0_bigint = + c.c0.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a1_bigint = + c.c0.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a0_bigint = + c.c1.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a1_bigint = + c.c1.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a0_bigint = + c.c1.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a1_bigint = + c.c1.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a0_bigint = + c.c1.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a1_bigint = + c.c1.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_b0_a0_bigint, + c_c0_b0_a1_bigint, + c_c0_b1_a0_bigint, + c_c0_b1_a1_bigint, + c_c0_b2_a0_bigint, + c_c0_b2_a1_bigint, + c_c1_b0_a0_bigint, + c_c1_b0_a1_bigint, + c_c1_b1_a0_bigint, + c_c1_b1_a1_bigint, + c_c1_b2_a0_bigint, + c_c1_b2_a1_bigint, + ] + }, + ) { + panic!("divE12Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let a_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a_c0 = Fq6::new(a_c0_b0, a_c0_b1, a_c0_b2); + let a_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let a_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let a_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let a_c1 = Fq6::new(a_c1_b0, a_c1_b1, a_c1_b2); + let a = Fq12::new(a_c0, a_c1); + + let c = a.inverse().unwrap(); + let c_c0_b0_a0_bigint = + c.c0.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b0_a1_bigint = + c.c0.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a0_bigint = + c.c0.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a1_bigint = + c.c0.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a0_bigint = + c.c0.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a1_bigint = + c.c0.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a0_bigint = + c.c1.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a1_bigint = + c.c1.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a0_bigint = + c.c1.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a1_bigint = + c.c1.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a0_bigint = + c.c1.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a1_bigint = + c.c1.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_b0_a0_bigint, + c_c0_b0_a1_bigint, + c_c0_b1_a0_bigint, + c_c0_b1_a1_bigint, + c_c0_b2_a0_bigint, + c_c0_b2_a1_bigint, + c_c1_b0_a0_bigint, + c_c1_b0_a1_bigint, + c_c1_b1_a0_bigint, + c_c1_b1_a1_bigint, + c_c1_b2_a0_bigint, + c_c1_b2_a1_bigint, + ] + }, + ) { + panic!("inverseE12Hint: {}", err); + } + Ok(()) +} +pub fn copy_vars_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + outputs.copy_from_slice(&inputs[..outputs.len()]); + Ok(()) +} +pub fn copy_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE2Hint + |inputs| inputs, + ) { + panic!("copyE2Hint: {}", err); + } + Ok(()) +} +pub fn copy_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE12Hint + |inputs| inputs, + ) { + panic!("copyE12Hint: {}", err); + } + Ok(()) +} +pub fn final_exp_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //finalExpHint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let mut miller_loop = Fq12::default(); + miller_loop.c0.c0.c0 = Fq::from(biguint_inputs[0].clone()); + miller_loop.c0.c0.c1 = Fq::from(biguint_inputs[1].clone()); + miller_loop.c0.c1.c0 = Fq::from(biguint_inputs[2].clone()); + miller_loop.c0.c1.c1 = Fq::from(biguint_inputs[3].clone()); + miller_loop.c0.c2.c0 = Fq::from(biguint_inputs[4].clone()); + miller_loop.c0.c2.c1 = Fq::from(biguint_inputs[5].clone()); + miller_loop.c1.c0.c0 = Fq::from(biguint_inputs[6].clone()); + miller_loop.c1.c0.c1 = Fq::from(biguint_inputs[7].clone()); + miller_loop.c1.c1.c0 = Fq::from(biguint_inputs[8].clone()); + miller_loop.c1.c1.c1 = Fq::from(biguint_inputs[9].clone()); + miller_loop.c1.c2.c0 = Fq::from(biguint_inputs[10].clone()); + miller_loop.c1.c2.c1 = Fq::from(biguint_inputs[11].clone()); + + let mut root_pth_inverse = Fq12::default(); + let mut root_27th_inverse = Fq12::default(); + let order3rd; + let mut order3rd_power = BigInt::default(); + let mut exponent: BigInt; + let mut exponent_inv; + let poly_factor = + BigInt::from_str("5044125407647214251").expect("Invalid string for BigInt"); + let final_exp_factor= BigInt::from_str("2366356426548243601069753987687709088104621721678962410379583120840019275952471579477684846670499039076873213559162845121989217658133790336552276567078487633052653005423051750848782286407340332979263075575489766963251914185767058009683318020965829271737924625612375201545022326908440428522712877494557944965298566001441468676802477524234094954960009227631543471415676620753242466901942121887152806837594306028649150255258504417829961387165043999299071444887652375514277477719817175923289019181393803729926249507024121957184340179467502106891835144220611408665090353102353194448552304429530104218473070114105759487413726485729058069746063140422361472585604626055492939586602274983146215294625774144156395553405525711143696689756441298365274341189385646499074862712688473936093315628166094221735056483459332831845007196600723053356837526749543765815988577005929923802636375670820616189737737304893769679803809426304143627363860243558537831172903494450556755190448279875942974830469855835666815454271389438587399739607656399812689280234103023464545891697941661992848552456326290792224091557256350095392859243101357349751064730561345062266850238821755009430903520645523345000326783803935359711318798844368754833295302563158150573540616830138810935344206231367357992991289265295323280").expect("Invalid string for BigInt"); + exponent = &final_exp_factor * 27; + let exp_uint = exponent.to_biguint().unwrap(); + let root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + if root.is_one() { + root_pth_inverse.set_one(); + } else { + exponent_inv = exponent.clone().modinv(&poly_factor).unwrap(); + if exponent_inv.abs() > poly_factor { + exponent_inv %= &poly_factor; + } + exponent = &poly_factor - exponent_inv; + exponent %= &poly_factor; + let exp_uint = exponent.to_biguint().unwrap(); + root_pth_inverse = root.pow(exp_uint.to_u64_digits().iter()); + } + + let three = BigUint::from(3u32); + exponent = &poly_factor * &final_exp_factor; + let exp_uint = exponent.to_biguint().unwrap(); + let mut root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(0u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(1u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(2u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(3u32); + } + + if order3rd_power.is_zero() { + root_27th_inverse.set_one(); + } else { + let three_bigint = BigInt::from(3u32); + order3rd = three_bigint.pow(order3rd_power.to_u32().unwrap()); + exponent = &poly_factor * &final_exp_factor; + let exp_uint = exponent.to_biguint().unwrap(); + root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + exponent_inv = exponent.modinv(&order3rd).unwrap(); + if exponent_inv.abs() > order3rd { + exponent_inv %= &order3rd; + } + exponent = &order3rd - exponent_inv; + exponent %= &order3rd; + let exp_uint = exponent.to_biguint().unwrap(); + root_27th_inverse = root.pow(exp_uint.to_u64_digits().iter()); + } + + let scaling_factor = root_pth_inverse * root_27th_inverse; + miller_loop *= scaling_factor; + + let lambda= BigInt::from_str("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129030796414117214202539").expect("Invalid string for BigInt"); + exponent = lambda.modinv(&final_exp_factor).unwrap(); + let residue_witness = + miller_loop.pow(exponent.to_biguint().unwrap().to_u64_digits().iter()); + + let res_c0_b0_a0_bigint = residue_witness + .c0 + .c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b0_a1_bigint = residue_witness + .c0 + .c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b1_a0_bigint = residue_witness + .c0 + .c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b1_a1_bigint = residue_witness + .c0 + .c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b2_a0_bigint = residue_witness + .c0 + .c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b2_a1_bigint = residue_witness + .c0 + .c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b0_a0_bigint = residue_witness + .c1 + .c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b0_a1_bigint = residue_witness + .c1 + .c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b1_a0_bigint = residue_witness + .c1 + .c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b1_a1_bigint = residue_witness + .c1 + .c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b2_a0_bigint = residue_witness + .c1 + .c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b2_a1_bigint = residue_witness + .c1 + .c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + let sca_c0_b0_a0_bigint = scaling_factor + .c0 + .c0 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b0_a1_bigint = scaling_factor + .c0 + .c0 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b1_a0_bigint = scaling_factor + .c0 + .c1 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b1_a1_bigint = scaling_factor + .c0 + .c1 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b2_a0_bigint = scaling_factor + .c0 + .c2 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b2_a1_bigint = scaling_factor + .c0 + .c2 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + + vec![ + res_c0_b0_a0_bigint, + res_c0_b0_a1_bigint, + res_c0_b1_a0_bigint, + res_c0_b1_a1_bigint, + res_c0_b2_a0_bigint, + res_c0_b2_a1_bigint, + res_c1_b0_a0_bigint, + res_c1_b0_a1_bigint, + res_c1_b1_a0_bigint, + res_c1_b1_a1_bigint, + res_c1_b2_a0_bigint, + res_c1_b2_a1_bigint, + sca_c0_b0_a0_bigint, + sca_c0_b0_a1_bigint, + sca_c0_b1_a0_bigint, + sca_c0_b1_a1_bigint, + sca_c0_b2_a0_bigint, + sca_c0_b2_a1_bigint, + ] + }, + ) { + panic!("inverseE12Hint: {}", err); + } + Ok(()) +} + +pub fn simple_rangecheck_hint(inputs: &[M31], _outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_u32(); + let number = inputs[1].to_u256().as_f64(); + let number_bit = if number > 1.0 { + number.log2().ceil() as u32 + } else { + 1 + }; + if number_bit > nb_bits { + panic!("number is out of range"); + } + + Ok(()) +} + +pub fn unwrap_hint( + is_emulated_input: bool, + is_emulated_output: bool, + native_inputs: &[M31], + native_outputs: &mut [M31], + nonnative_hint: fn(Vec) -> Vec, +) -> Result<(), String> { + if native_inputs.len() < 2 { + return Err("hint wrapper header is 2 elements".to_string()); + } + let i64_max = 1 << 63; + if native_inputs[0].to_u256() >= i64_max || native_inputs[1].to_u256() >= i64_max { + return Err("header must be castable to int64".to_string()); + } + let nb_bits = native_inputs[0].to_u256().as_u32(); + let nb_limbs = native_inputs[1].to_u256().as_usize(); + if native_inputs.len() < 2 + nb_limbs { + return Err("hint wrapper header is 2+nbLimbs elements".to_string()); + } + let nonnative_mod_limbs = + m31_to_bigint_array(native_inputs[2..2 + nb_limbs].to_vec().as_slice()); + let nonnative_mod = recompose(nonnative_mod_limbs, nb_bits); + let mut nonnative_inputs; + if is_emulated_input { + if native_inputs[2 + nb_limbs].to_u256() >= i64_max { + return Err("number of nonnative elements must be castable to int64".to_string()); + } + let nb_inputs = native_inputs[2 + nb_limbs].to_u256().as_usize(); + let mut read_ptr = 3 + nb_limbs; + nonnative_inputs = vec![BigInt::default(); nb_inputs]; + for (i, nonnative_input) in nonnative_inputs.iter_mut().enumerate().take(nb_inputs) { + if native_inputs.len() < read_ptr + 1 { + return Err(format!("can not read {}-th native input", i)); + } + if native_inputs[read_ptr].to_u256() >= i64_max { + return Err(format!("corrupted {}-th native input", i)); + } + let current_input_len = native_inputs[read_ptr].to_u256().as_usize(); + if native_inputs.len() < read_ptr + 1 + current_input_len { + return Err(format!("cannot read {}-th nonnative element", i)); + } + let tmp_inputs = m31_to_bigint_array( + native_inputs[read_ptr + 1..read_ptr + 1 + current_input_len] + .to_vec() + .as_slice(), + ); + *nonnative_input = recompose(tmp_inputs, nb_bits); + read_ptr += 1 + current_input_len; + } + } else { + let nb_inputs = native_inputs[2 + nb_limbs..].len(); + let read_ptr = 2 + nb_limbs; + nonnative_inputs = vec![BigInt::default(); nb_inputs]; + for i in 0..nb_inputs { + nonnative_inputs[i] = m31_to_bigint(native_inputs[read_ptr + i]); + } + } + let nonnative_outputs = nonnative_hint(nonnative_inputs); + let mut tmp_outputs = vec![BigInt::default(); nb_limbs * nonnative_outputs.len()]; + if is_emulated_output { + if native_outputs.len() % nb_limbs != 0 { + return Err("output count doesn't divide limb count".to_string()); + } + for i in 0..nonnative_outputs.len() { + let mod_output = &nonnative_outputs[i] % &nonnative_mod; + if let Err(e) = decompose( + &mod_output, + nb_bits, + &mut tmp_outputs[i * nb_limbs..(i + 1) * nb_limbs], + ) { + return Err(format!("decompose {}-th element: {}", i, e)); + } + } + } else { + tmp_outputs[..nonnative_outputs.len()].clone_from_slice(&nonnative_outputs[..]); + } + for i in 0..tmp_outputs.len() { + native_outputs[i] = bigint_to_m31(&tmp_outputs[i]); + } + Ok(()) +} diff --git a/circuit-std-rs/src/gnark/limbs.rs b/circuit-std-rs/src/gnark/limbs.rs new file mode 100644 index 00000000..469535b6 --- /dev/null +++ b/circuit-std-rs/src/gnark/limbs.rs @@ -0,0 +1,39 @@ +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::ToPrimitive; +pub fn recompose(inputs: Vec, nb_bits: u32) -> BigInt { + if inputs.is_empty() { + panic!("zero length slice input"); + } + let mut res = BigInt::from(0u32); + for i in 0..inputs.len() { + res <<= nb_bits; + res += &inputs[inputs.len() - i - 1]; + } + res +} +pub fn decompose(input: &BigInt, nb_bits: u32, res: &mut [BigInt]) -> Result<(), String> { + // limb modulus + if input.bits() > res.len() as u64 * nb_bits as u64 { + return Err("decomposed integer does not fit into res".to_string()); + } + let base = BigInt::from(1u32) << nb_bits; + let mut tmp = input.clone(); + for cur_res in res { + *cur_res = &tmp % &base; + tmp >>= nb_bits; + } + Ok(()) +} + +pub fn m31_to_bigint(input: M31) -> BigInt { + BigInt::from(input.to_u256().as_u32()) +} + +pub fn bigint_to_m31(input: &BigInt) -> M31 { + M31::from(input.to_u32().unwrap()) +} + +pub fn m31_to_bigint_array(input: &[M31]) -> Vec { + input.iter().map(|x| m31_to_bigint(*x)).collect() +} diff --git a/circuit-std-rs/src/gnark/mod.rs b/circuit-std-rs/src/gnark/mod.rs new file mode 100644 index 00000000..93e7b082 --- /dev/null +++ b/circuit-std-rs/src/gnark/mod.rs @@ -0,0 +1,7 @@ +pub mod element; +pub mod emparam; +pub mod emulated; +pub mod field; +pub mod hints; +pub mod limbs; +pub mod utils; diff --git a/circuit-std-rs/src/gnark/utils.rs b/circuit-std-rs/src/gnark/utils.rs new file mode 100644 index 00000000..4b09ff93 --- /dev/null +++ b/circuit-std-rs/src/gnark/utils.rs @@ -0,0 +1,61 @@ +use num_bigint::BigInt; + +use crate::gnark::element::*; +use crate::gnark::emparam::FieldParams; +use crate::gnark::emulated::field_bls12381::e2::GE2; +use crate::gnark::limbs::decompose; +use crate::gnark::limbs::recompose; +use expander_compiler::frontend::*; + +pub fn nb_multiplication_res_limbs(len_left: usize, len_right: usize) -> usize { + let res = len_left + len_right - 1; + if len_left + len_right < 1 { + 0 + } else { + res + } +} + +pub fn sub_padding( + modulus: &BigInt, + bits_per_limbs: u32, + overflow: u32, + nb_limbs: u32, +) -> Vec { + if modulus == &BigInt::default() { + panic!("modulus is zero"); + } + let mut n_limbs = vec![BigInt::default(); nb_limbs as usize]; + for n_limb in &mut n_limbs { + *n_limb = BigInt::from(1) << (overflow + bits_per_limbs); + } + let mut n = recompose(n_limbs.clone(), bits_per_limbs); + n %= modulus; + n = modulus - n; + let mut pad = vec![BigInt::default(); nb_limbs as usize]; + if let Err(err) = decompose(&n, bits_per_limbs, &mut pad) { + panic!("decompose: {}", err); + } + let mut new_pad = vec![BigInt::default(); nb_limbs as usize]; + for i in 0..pad.len() { + new_pad[i] = pad[i].clone() + n_limbs[i].clone(); + } + new_pad +} + +pub fn print_e2>(native: &mut B, v: &GE2) { + for i in 0..48 { + println!( + "{}: {:?} {:?}", + i, + native.display("", v.a0.limbs[i]), + native.display("", v.a1.limbs[i]) + ); + } +} +pub fn print_element, T: FieldParams>(native: &mut B, v: &Element) { + for i in 0..v.limbs.len() { + print!("{:?} ", native.display("", v.limbs[i])); + } + println!(" "); +} diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 248446f9..3baeade3 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -3,3 +3,8 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; + +pub mod gnark; +pub mod poseidon_m31; +pub mod sha256; +pub mod utils; diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 25911b78..4b399d50 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use arith::Field; use expander_compiler::frontend::*; use rand::Rng; @@ -31,7 +33,11 @@ struct Rational { denominator: Variable, } -fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) -> Rational { +fn add_rational>( + builder: &mut B, + v1: &Rational, + v2: &Rational, +) -> Rational { let p1 = builder.mul(v1.numerator, v2.denominator); let p2 = builder.mul(v1.denominator, v2.numerator); @@ -41,13 +47,13 @@ fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) - } } -fn assert_eq_rational(builder: &mut API, v1: &Rational, v2: &Rational) { +fn assert_eq_rational>(builder: &mut B, v1: &Rational, v2: &Rational) { let p1 = builder.mul(v1.numerator, v2.denominator); let p2 = builder.mul(v1.denominator, v2.numerator); builder.assert_is_equal(p1, p2); } -fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rational { +fn sum_rational_vec>(builder: &mut B, vs: &[Rational]) -> Rational { if vs.is_empty() { return Rational { numerator: builder.constant(0), @@ -83,8 +89,17 @@ fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rationa vvs[0] } -// TODO-Feature: poly randomness -fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { +fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { + v1.iter() + .zip(v2.iter()) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect() +} + +fn get_column_randomness>( + builder: &mut B, + n_columns: usize, +) -> Vec { let mut randomness = vec![]; randomness.push(builder.constant(1)); for _ in 1..n_columns { @@ -93,15 +108,8 @@ fn get_column_randomness(builder: &mut API, n_columns: usize) -> V randomness } -fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { - v1.iter() - .zip(v2.iter()) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect() -} - -fn combine_columns( - builder: &mut API, +fn combine_columns>( + builder: &mut B, vec_2d: &[Vec], randomness: &[Variable], ) -> Vec { @@ -124,8 +132,8 @@ fn combine_columns( .collect() } -fn logup_poly_val( - builder: &mut API, +fn logup_poly_val>( + builder: &mut B, vals: &[Variable], counts: &[Variable], x: &Variable, @@ -230,3 +238,253 @@ impl StdCircuit for LogUpCircuit { assignment } } + +pub struct LogUpSingleKeyTable { + pub table: Vec>, + pub query_keys: Vec, + pub query_results: Vec>, +} + +impl LogUpSingleKeyTable { + pub fn new(_nb_bits: usize) -> Self { + Self { + table: vec![], + query_keys: vec![], + query_results: vec![], + } + } + + pub fn new_table(&mut self, key: Vec, value: Vec>) { + if key.len() != value.len() { + panic!("key and value should have the same length"); + } + if !self.table.is_empty() { + panic!("table already exists"); + } + for i in 0..key.len() { + let mut entry = vec![key[i]]; + entry.extend(value[i].clone()); + self.table.push(entry); + } + } + + pub fn add_table_row(&mut self, key: Variable, value: Vec) { + let mut entry = vec![key]; + entry.extend(value.clone()); + self.table.push(entry); + } + + fn add_query(&mut self, key: Variable, value: Vec) { + let mut entry = vec![key]; + entry.extend(value.clone()); + self.query_keys.push(key); + self.query_results.push(entry); + } + + pub fn query(&mut self, key: Variable, value: Vec) { + self.add_query(key, value); + } + + pub fn batch_query(&mut self, keys: Vec, values: Vec>) { + for i in 0..keys.len() { + self.add_query(keys[i], values[i].clone()); + } + } + + pub fn final_check>(&mut self, builder: &mut B) { + if self.table.is_empty() || self.query_keys.is_empty() { + panic!("empty table or empty query"); + } + + let value_len = self.table[0].len(); + + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, value_len); + + let table_combined = combine_columns(builder, &self.table, &randomness); + let mut inputs = vec![builder.constant(self.table.len() as u32)]; + //append table keys + for i in 0..self.table.len() { + inputs.push(self.table[i][0]); + } + //append query keys + inputs.extend(self.query_keys.clone()); + + let query_count = builder.new_hint("myhint.querycountbykeyhint", &inputs, self.table.len()); + + let v_table = logup_poly_val(builder, &table_combined, &query_count, &alpha); + + let query_combined = combine_columns(builder, &self.query_results, &randomness); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } +} + +pub struct LogUpRangeProofTable { + pub table_keys: Vec, + pub query_keys: Vec, + pub rangeproof_bits: usize, +} + +impl LogUpRangeProofTable { + pub fn new(nb_bits: usize) -> Self { + Self { + table_keys: vec![], + query_keys: vec![], + rangeproof_bits: nb_bits, + } + } + + pub fn initial>(&mut self, builder: &mut B) { + for i in 0..1 << self.rangeproof_bits { + let key = builder.constant(i as u32); + self.add_table_row(key); + } + } + + pub fn add_table_row(&mut self, key: Variable) { + self.table_keys.push(key); + } + + pub fn add_query(&mut self, key: Variable) { + self.query_keys.push(key); + } + + pub fn rangeproof>(&mut self, builder: &mut B, a: Variable, n: usize) { + //add a shift value + let mut n = n; + let mut new_a = a; + if n % self.rangeproof_bits != 0 { + let rem = n % self.rangeproof_bits; + let shift = self.rangeproof_bits - rem; + let constant = (1 << shift) - 1; + let mut mul_factor = 1; + // println!("n:{}", n); + mul_factor <<= n; + let a_shift = builder.mul(constant, mul_factor); + new_a = builder.add(a, a_shift); + n += shift; + } + let hint_input = vec![ + builder.constant(n as u32), + builder.constant(self.rangeproof_bits as u32), + new_a, + ]; + let witnesses = builder.new_hint( + "myhint.rangeproofhint", + &hint_input, + n / self.rangeproof_bits, + ); + let mut sum = witnesses[0]; + for (i, witness) in witnesses.iter().enumerate().skip(1) { + let constant = 1 << (self.rangeproof_bits * i); + let constant = builder.constant(constant); + let mul = builder.mul(witness, constant); + sum = builder.add(sum, mul); + } + builder.assert_is_equal(sum, new_a); + for witness in witnesses.iter().take(n / self.rangeproof_bits) { + self.query_range(*witness); + } + } + + pub fn rangeproof_onechunk>( + &mut self, + builder: &mut B, + a: Variable, + n: usize, + ) { + //n must be less than self.rangeproof_bits, not need the hint + if n > self.rangeproof_bits { + panic!("n must be less than self.rangeproof_bits"); + } + //add a shift value + let mut new_a = a; + if n % self.rangeproof_bits != 0 { + let rem = n % self.rangeproof_bits; + let shift = self.rangeproof_bits - rem; + let constant = (1 << shift) - 1; + let mut mul_factor = 0; + mul_factor <<= n; + let a_shift = builder.mul(constant, mul_factor); + new_a = builder.add(a, a_shift); + } + self.query_range(new_a); + } + + pub fn query_range(&mut self, key: Variable) { + self.query_keys.push(key); + } + + pub fn final_check>(&mut self, builder: &mut B) { + let alpha = builder.get_random_value(); + let inputs = self.query_keys.clone(); + let query_count = builder.new_hint("myhint.querycounthint", &inputs, self.table_keys.len()); + let v_table = logup_poly_val(builder, &self.table_keys, &query_count, &alpha); + + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &self.query_keys, + &vec![one; self.query_keys.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); + } +} + +pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let mut count = vec![0; outputs.len()]; + for input in inputs { + let query_id = input.to_u256().as_usize(); + count[query_id] += 1; + } + for i in 0..outputs.len() { + outputs[i] = M31::from(count[i] as u32); + } + Ok(()) +} + +pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let mut outputs_u32 = vec![0; outputs.len()]; + + let table_size = inputs[0].to_u256().as_usize(); + let table = &inputs[1..=table_size]; + let query_keys = &inputs[(table_size + 1)..]; + + let mut table_map: HashMap = HashMap::new(); + for key in query_keys { + let key_value = key.to_u256().as_u32(); + *table_map.entry(key_value).or_insert(0) += 1; + } + + for (i, value) in table.iter().enumerate() { + let key_value = value.to_u256().as_u32(); + let count = table_map.get(&key_value).copied().unwrap_or(0); + outputs_u32[i] = count as u32; + } + for i in 0..outputs.len() { + outputs[i] = M31::from(outputs_u32[i]); + } + + Ok(()) +} + +pub fn rangeproof_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let n = inputs[0].to_u256().as_i64(); + let m = inputs[1].to_u256().as_i64(); + let mut a = inputs[2].to_u256().as_i64(); + for i in 0..n / m { + let r = a % (1 << m); + a /= 1 << m; + outputs[i as usize] = M31::from(r as u32); + } + Ok(()) +} diff --git a/circuit-std-rs/src/poseidon_m31.rs b/circuit-std-rs/src/poseidon_m31.rs new file mode 100644 index 00000000..edee2e6d --- /dev/null +++ b/circuit-std-rs/src/poseidon_m31.rs @@ -0,0 +1,188 @@ +use expander_compiler::frontend::*; +use tiny_keccak::{Hasher, Keccak}; + +const POSEIDON_SEED_PREFIX: &str = "poseidon_seed"; + +const FIELD_NAME: &str = "Mersenne 31"; + +fn get_constants(width: usize, round_num: usize) -> Vec> { + let seed = format!("{POSEIDON_SEED_PREFIX}_{}_{}", FIELD_NAME, width); + + let mut keccak = Keccak::v256(); + let mut buffer = [0u8; 32]; + keccak.update(seed.as_bytes()); + keccak.finalize(&mut buffer); + + let mut res = vec![vec![0u32; width]; round_num]; + + (0..round_num).for_each(|i| { + (0..width).for_each(|j| { + let mut keccak = Keccak::v256(); + keccak.update(&buffer); + keccak.finalize(&mut buffer); + + let mut u32_le_bytes = [0u8; 4]; + u32_le_bytes.copy_from_slice(&buffer[..4]); + + res[i][j] = u32::from_le_bytes(u32_le_bytes); + }); + }); + + res +} + +const MATRIX_CIRC_MDS_8_SML_ROW: [u32; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + +const MATRIX_CIRC_MDS_12_SML_ROW: [u32; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + +const MATRIX_CIRC_MDS_16_SML_ROW: [u32; 16] = + [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3]; + +fn get_mds_matrix(width: usize) -> Vec> { + let mds_first_row: &[u32] = match width { + 8 => &MATRIX_CIRC_MDS_8_SML_ROW, + 12 => &MATRIX_CIRC_MDS_12_SML_ROW, + 16 => &MATRIX_CIRC_MDS_16_SML_ROW, + _ => panic!("unsupported state width for MDS matrix"), + }; + + let mut res = vec![vec![0u32; width]; width]; + + (0..width).for_each(|i| (0..width).for_each(|j| res[i][j] = mds_first_row[(i + j) % width])); + + res +} + +fn power_5>(api: &mut B, base: Variable) -> Variable { + let pow2 = api.mul(base, base); + let pow4 = api.mul(pow2, pow2); + api.mul(pow4, base) +} + +pub struct PoseidonM31Params { + pub mds_matrix: Vec>, + pub round_constants: Vec>, + + pub rate: usize, + pub width: usize, + pub full_rounds: usize, + pub partial_rounds: usize, +} + +impl PoseidonM31Params { + pub fn new>( + api: &mut B, + rate: usize, + width: usize, + full_rounds: usize, + partial_rounds: usize, + ) -> Self { + let round_constants = get_constants(width, partial_rounds + full_rounds); + let mds_matrix = get_mds_matrix(width); + + let round_constants_variables = (0..partial_rounds + full_rounds) + .map(|i| { + (0..width) + .map(|j| api.constant(round_constants[i][j])) + .collect::>() + }) + .collect::>(); + + let mds_matrix_variables = (0..width) + .map(|i| { + (0..width) + .map(|j| api.constant(mds_matrix[i][j])) + .collect::>() + }) + .collect::>(); + + Self { + mds_matrix: mds_matrix_variables, + round_constants: round_constants_variables, + rate, + width, + full_rounds, + partial_rounds, + } + } + + fn add_round_constants>( + &self, + api: &mut B, + state: &mut [Variable], + constants: &[Variable], + ) { + (0..self.width).for_each(|i| state[i] = api.add(state[i], constants[i])) + } + + fn apply_mds_matrix>(&self, api: &mut B, state: &mut [Variable]) { + let prev_state = state.to_vec(); + + (0..self.width).for_each(|i| { + let mut inner_product = api.constant(0); + (0..self.width).for_each(|j| { + let unit = api.mul(prev_state[j], self.mds_matrix[i][j]); + inner_product = api.add(inner_product, unit); + }); + state[i] = inner_product; + }) + } + + fn partial_full_sbox>(&self, api: &mut B, state: &mut [Variable]) { + state[0] = power_5(api, state[0]) + } + + fn apply_full_sbox>(&self, api: &mut B, state: &mut [Variable]) { + state.iter_mut().for_each(|s| *s = power_5(api, *s)) + } + + pub fn permute>(&self, api: &mut B, state: &mut [Variable]) { + let half_full_rounds = self.full_rounds / 2; + let partial_ends = half_full_rounds + self.partial_rounds; + + assert_eq!(self.width, state.len()); + + (0..half_full_rounds).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.apply_full_sbox(api, state) + }); + (half_full_rounds..partial_ends).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.partial_full_sbox(api, state) + }); + (partial_ends..half_full_rounds + partial_ends).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.apply_full_sbox(api, state) + }); + } + + pub fn hash_to_state>( + &self, + api: &mut B, + inputs: &[Variable], + ) -> Vec { + let mut elts = inputs.to_vec(); + elts.resize(elts.len().next_multiple_of(self.rate), api.constant(0)); + + let mut res = vec![api.constant(0); self.width]; + + elts.chunks(self.rate).for_each(|chunk| { + let mut state_elts = vec![api.constant(0); self.width - self.rate]; + state_elts.extend_from_slice(chunk); + + (0..self.width).for_each(|i| res[i] = api.add(res[i], state_elts[i])); + self.permute(api, &mut res) + }); + + res + } +} + +pub const POSEIDON_M31X16_FULL_ROUNDS: usize = 8; + +pub const POSEIDON_M31X16_PARTIAL_ROUNDS: usize = 14; + +pub const POSEIDON_M31X16_RATE: usize = 8; diff --git a/circuit-std-rs/src/sha256.rs b/circuit-std-rs/src/sha256.rs new file mode 100644 index 00000000..8b142ecc --- /dev/null +++ b/circuit-std-rs/src/sha256.rs @@ -0,0 +1,8 @@ +// The implementation of sha256 for the M31 and GF2 field + +// The Std trait for M31 haven't been implemented yet, see test_m31.rs for the usage +pub mod m31; +pub mod m31_utils; + +pub mod gf2; +pub mod gf2_utils; diff --git a/circuit-std-rs/src/sha256/gf2.rs b/circuit-std-rs/src/sha256/gf2.rs new file mode 100644 index 00000000..9519fbcc --- /dev/null +++ b/circuit-std-rs/src/sha256/gf2.rs @@ -0,0 +1,122 @@ +use expander_compiler::frontend::*; + +use super::gf2_utils::*; + +#[derive(Clone, Debug, Default)] +pub struct SHA256GF2 { + data: Vec, +} + +const SHA256_INIT_STATE: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +impl SHA256GF2 { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + // data can have arbitrary length, do not have to be aligned to 512 bits + pub fn update(&mut self, data: &[Variable]) { + self.data.extend(data); + } + + // finalize the hash, return the hash value + pub fn finalize(&mut self, api: &mut impl RootAPI) -> Vec { + let data_len = self.data.len(); + + // padding according to the sha256 padding rule: https://helix.stormhub.org/papers/SHA-256.pdf + // append a bit '1' first + self.data.push(api.constant(1)); + // append '0' bits to make the length of data congruent to 448 mod 512 + let zero_padding_len = 448 - ((data_len + 1) % 512); + self.data + .extend((0..zero_padding_len).map(|_| api.constant(0))); + // append the length of the data in 64 bits + self.data.extend(u64_to_bit(api, data_len as u64)); + + let mut state = SHA256_INIT_STATE + .iter() + .map(|x| u32_to_bit(api, *x)) + .collect::>() + .try_into() + .unwrap(); + self.data.chunks_exact(512).for_each(|chunk| { + self.sha256_compress(api, &mut state, chunk.try_into().unwrap()); + }); + + state.iter().flatten().cloned().collect() + } + + // The compress function, usually not used directly + pub fn sha256_compress( + &self, + api: &mut impl RootAPI, + state: &mut [Sha256Word; 8], + input: &[Variable; 512], + ) { + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state; + // self.display_state(api, state); + + let mut w = [[api.constant(0); 32]; 64]; + for i in 0..16 { + w[i] = input[(i * 32)..((i + 1) * 32)].try_into().unwrap(); + } + for i in 16..64 { + let lower_sigma1 = lower_case_sigma1(api, &w[i - 2]); + let s0 = add(api, &lower_sigma1, &w[i - 7]); + + let lower_sigma0 = lower_case_sigma0(api, &w[i - 15]); + let s1 = add(api, &lower_sigma0, &w[i - 16]); + + w[i] = add(api, &s0, &s1); + } + + for i in 0..64 { + let w_plus_k = add_const(api, &w[i], SHA256_K[i]); + let capital_sigma_1_e = capital_sigma1(api, &e); + let ch_e_f_g = ch(api, &e, &f, &g); + let t_1 = sum_all(api, &[h, capital_sigma_1_e, ch_e_f_g, w_plus_k]); + + let capital_sigma_0_a = capital_sigma0(api, &a); + let maj_a_b_c = maj(api, &a, &b, &c); + let t_2 = add(api, &capital_sigma_0_a, &maj_a_b_c); + + h = g; + g = f; + f = e; + e = add(api, &d, &t_1); + d = c; + c = b; + b = a; + a = add(api, &t_1, &t_2); + } + + state[0] = add(api, &state[0], &a); + state[1] = add(api, &state[1], &b); + state[2] = add(api, &state[2], &c); + state[3] = add(api, &state[3], &d); + state[4] = add(api, &state[4], &e); + state[5] = add(api, &state[5], &f); + state[6] = add(api, &state[6], &g); + state[7] = add(api, &state[7], &h); + } + + #[allow(dead_code)] + fn display_state(&self, api: &mut impl RootAPI, state: &[Sha256Word; 8]) { + for (i, s) in state.iter().enumerate() { + api.display(&format!("{}", i), s[30]); + } + } +} diff --git a/circuit-std-rs/src/sha256/gf2_utils.rs b/circuit-std-rs/src/sha256/gf2_utils.rs new file mode 100644 index 00000000..c5d68ef8 --- /dev/null +++ b/circuit-std-rs/src/sha256/gf2_utils.rs @@ -0,0 +1,324 @@ +use expander_compiler::frontend::*; + +pub type Sha256Word = [Variable; 32]; + +// parse the u32 into 32 bits, big-endian +pub fn u32_to_bit>(api: &mut Builder, value: u32) -> [Variable; 32] { + (0..32) + .map(|i| api.constant((value >> (31 - i)) & 1)) + .collect::>() + .try_into() + .expect("Iterator should have exactly 32 elements") +} + +pub fn u64_to_bit>(api: &mut Builder, value: u64) -> [Variable; 64] { + (0..64) + .map(|i| api.constant(((value >> (63 - i)) & 1) as u32)) + .collect::>() + .try_into() + .expect("Iterator should have exactly 64 elements") +} + +pub fn rotate_right(bits: &Sha256Word, k: usize) -> Sha256Word { + assert!(bits.len() & (bits.len() - 1) == 0); + let n = bits.len(); + let s = n - k; + let mut new_bits = bits[s..].to_vec(); + new_bits.append(&mut bits[0..s].to_vec()); + new_bits.try_into().unwrap() +} + +pub fn shift_right>( + api: &mut Builder, + bits: &Sha256Word, + k: usize, +) -> Sha256Word { + assert!(bits.len() & (bits.len() - 1) == 0); + let n = bits.len(); + let s = n - k; + let mut new_bits = vec![api.constant(0); k]; + new_bits.append(&mut bits[0..s].to_vec()); + new_bits.try_into().unwrap() +} + +// Ch function: (x AND y) XOR (NOT x AND z) +pub fn ch>( + api: &mut Builder, + x: &Sha256Word, + y: &Sha256Word, + z: &Sha256Word, +) -> Sha256Word { + let xy = and(api, x, y); + let not_x = not(api, x); + let not_xz = and(api, ¬_x, z); + + xor(api, &xy, ¬_xz) +} + +// Maj function: (x AND y) XOR (x AND z) XOR (y AND z) +pub fn maj>( + api: &mut Builder, + x: &Sha256Word, + y: &Sha256Word, + z: &Sha256Word, +) -> Sha256Word { + let xy = and(api, x, y); + let xz = and(api, x, z); + let yz = and(api, y, z); + let tmp = xor(api, &xy, &xz); + + xor(api, &tmp, &yz) +} + +// sigma0 function: ROTR(x, 7) XOR ROTR(x, 18) XOR SHR(x, 3) +pub fn lower_case_sigma0>( + api: &mut Builder, + word: &Sha256Word, +) -> Sha256Word { + let rot7 = rotate_right(word, 7); + let rot18 = rotate_right(word, 18); + let shft3 = shift_right(api, word, 3); + let tmp = xor(api, &rot7, &rot18); + + xor(api, &tmp, &shft3) +} + +pub fn lower_case_sigma1>( + api: &mut Builder, + word: &Sha256Word, +) -> Sha256Word { + let rot17 = rotate_right(word, 17); + let rot19 = rotate_right(word, 19); + let shft10 = shift_right(api, word, 10); + let tmp = xor(api, &rot17, &rot19); + + xor(api, &tmp, &shft10) +} + +// Sigma0 function: ROTR(x, 2) XOR ROTR(x, 13) XOR ROTR(x, 22) +pub fn capital_sigma0>( + api: &mut Builder, + x: &Sha256Word, +) -> Sha256Word { + let rot2 = rotate_right(x, 2); + let rot13 = rotate_right(x, 13); + let rot22 = rotate_right(x, 22); + let tmp = xor(api, &rot2, &rot13); + + xor(api, &tmp, &rot22) +} + +// Sigma1 function: ROTR(x, 6) XOR ROTR(x, 11) XOR ROTR(x, 25) +pub fn capital_sigma1>( + api: &mut Builder, + x: &Sha256Word, +) -> Sha256Word { + let rot6 = rotate_right(x, 6); + let rot11 = rotate_right(x, 11); + let rot25 = rotate_right(x, 25); + let tmp = xor(api, &rot6, &rot11); + + xor(api, &tmp, &rot25) +} + +pub fn add_const>( + api: &mut Builder, + a: &Sha256Word, + b: u32, +) -> Sha256Word { + let n = a.len(); + let mut c = *a; + let mut ci = api.constant(0); + for i in (0..n).rev() { + if (b >> (31 - i)) & 1 == 1 { + let p = api.add(a[i], 1); + c[i] = api.add(p, ci); + + ci = api.mul(ci, p); + ci = api.add(ci, a[i]); + } else { + c[i] = api.add(c[i], ci); + ci = api.mul(ci, a[i]); + } + } + c +} + +// The brentkung addition algorithm, recommended +pub fn add_brentkung>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + // temporary solution to change endianness, big -> little + let mut a = *a; + let mut b = *b; + a.reverse(); + b.reverse(); + + let mut c = vec![api.constant(0); 32]; + let mut ci = api.constant(0); + + for i in 0..8 { + let start = i * 4; + let end = start + 4; + + let (sum, ci_next) = brent_kung_adder_4_bits(api, &a[start..end], &b[start..end], ci); + ci = ci_next; + + c[start..end].copy_from_slice(&sum); + } + + // temporary solution to change endianness, little -> big + c.reverse(); + c.try_into().unwrap() +} + +fn brent_kung_adder_4_bits>( + api: &mut Builder, + a: &[Variable], + b: &[Variable], + carry_in: Variable, +) -> ([Variable; 4], Variable) { + let mut g = [api.constant(0); 4]; + let mut p = [api.constant(0); 4]; + + // Step 1: Generate and propagate + for i in 0..4 { + g[i] = api.mul(a[i], b[i]); + p[i] = api.add(a[i], b[i]); + } + + // Step 2: Prefix computation + let p1g0 = api.mul(p[1], g[0]); + let p0p1 = api.mul(p[0], p[1]); + let p2p3 = api.mul(p[2], p[3]); + + let g10 = api.add(g[1], p1g0); + let g20 = api.mul(p[2], g10); + let g20 = api.add(g[2], g20); + let g30 = api.mul(p[3], g20); + let g30 = api.add(g[3], g30); + + // Step 3: Calculate carries + let mut c = [api.constant(0); 5]; + c[0] = carry_in; + let tmp = api.mul(p[0], c[0]); + c[1] = api.add(g[0], tmp); + let tmp = api.mul(p0p1, c[0]); + c[2] = api.add(g10, tmp); + let tmp = api.mul(p[2], c[0]); + let tmp = api.mul(p0p1, tmp); + c[3] = api.add(g20, tmp); + let tmp = api.mul(p0p1, p2p3); + let tmp = api.mul(tmp, c[0]); + c[4] = api.add(g30, tmp); + + // Step 4: Calculate sum + let mut sum = [api.constant(0); 4]; + for i in 0..4 { + sum[i] = api.add(p[i], c[i]); + } + + (sum, c[4]) +} + +pub fn add>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + add_brentkung(api, a, b) +} + +pub fn sum_all>(api: &mut Builder, vs: &[Sha256Word]) -> Sha256Word { + let mut n_values_to_sum = vs.len(); + let mut vvs = vs.to_vec(); + + // Sum all values in a binary tree fashion to produce fewer layers in the circuit + while n_values_to_sum > 1 { + let half_size_floor = n_values_to_sum / 2; + for i in 0..half_size_floor { + vvs[i] = add(api, &vvs[i], &vvs[i + half_size_floor]) + } + + if n_values_to_sum & 1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum - 1]; + } + + n_values_to_sum = (n_values_to_sum + 1) / 2; + } + + vvs[0] +} + +fn bit_add_with_carry>( + api: &mut Builder, + a: Variable, + b: Variable, + carry: Variable, +) -> (Variable, Variable) { + let sum = api.add(a, b); + let sum = api.add(sum, carry); + + // a * (b + (b + 1) * carry) + (a + 1) * b * carry + // = a * b + a * b * carry + a * b * carry + a * carry + b * carry + let ab = api.mul(a, b); + let ac = api.mul(a, carry); + let bc = api.mul(b, carry); + let abc = api.mul(ab, carry); + + let carry_next = api.add(ab, abc); + let carry_next = api.add(carry_next, abc); + let carry_next = api.add(carry_next, ac); + let carry_next = api.add(carry_next, bc); + + (sum, carry_next) +} + +// The vanilla addition algorithm, not recommended +pub fn add_vanilla>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut c = vec![api.constant(0); 32]; + + let mut carry = api.constant(0); + for i in (0..32).rev() { + (c[i], carry) = bit_add_with_carry(api, a[i], b[i], carry); + } + c.try_into().unwrap() +} + +pub fn xor>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.add(a[i], b[i]); + } + bits_res +} + +pub fn and>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.mul(a[i], b[i]); + } + bits_res +} + +pub fn not>(api: &mut Builder, a: &Sha256Word) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.sub(1, a[i]); + } + bits_res +} diff --git a/circuit-std-rs/src/sha256/m31.rs b/circuit-std-rs/src/sha256/m31.rs new file mode 100644 index 00000000..d39d10b2 --- /dev/null +++ b/circuit-std-rs/src/sha256/m31.rs @@ -0,0 +1,289 @@ +use super::m31_utils::{ + big_array_add, big_endian_m31_array_put_uint32, bit_array_to_m31, bytes_to_bits, cap_sigma0, + cap_sigma1, ch, m31_to_bit_array, maj, sigma0, sigma1, +}; +use expander_compiler::frontend::*; + +const SHA256LEN: usize = 32; +const CHUNK: usize = 64; +const INIT0: u32 = 0x6A09E667; +const INIT1: u32 = 0xBB67AE85; +const INIT2: u32 = 0x3C6EF372; +const INIT3: u32 = 0xA54FF53A; +const INIT4: u32 = 0x510E527F; +const INIT5: u32 = 0x9B05688C; +const INIT6: u32 = 0x1F83D9AB; +const INIT7: u32 = 0x5BE0CD19; +//for m31 field (2^31-1), split each one to 2 30-bit element +const INIT00: u32 = INIT0 & 0x3FFFFFFF; +const INIT01: u32 = INIT0 >> 30; +const INIT10: u32 = INIT1 & 0x3FFFFFFF; +const INIT11: u32 = INIT1 >> 30; +const INIT20: u32 = INIT2 & 0x3FFFFFFF; +const INIT21: u32 = INIT2 >> 30; +const INIT30: u32 = INIT3 & 0x3FFFFFFF; +const INIT31: u32 = INIT3 >> 30; +const INIT40: u32 = INIT4 & 0x3FFFFFFF; +const INIT41: u32 = INIT4 >> 30; +const INIT50: u32 = INIT5 & 0x3FFFFFFF; +const INIT51: u32 = INIT5 >> 30; +const INIT60: u32 = INIT6 & 0x3FFFFFFF; +const INIT61: u32 = INIT6 >> 30; +const INIT70: u32 = INIT7 & 0x3FFFFFFF; +const INIT71: u32 = INIT7 >> 30; +const _K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +struct MyDigest { + h: [[Variable; 2]; 8], + nx: usize, + len: u64, + kbits: [[Variable; 32]; 64], +} + +impl MyDigest { + fn new>(api: &mut B) -> Self { + let mut h = [[api.constant(0); 2]; 8]; + h[0][0] = api.constant(INIT00); + h[0][1] = api.constant(INIT01); + h[1][0] = api.constant(INIT10); + h[1][1] = api.constant(INIT11); + h[2][0] = api.constant(INIT20); + h[2][1] = api.constant(INIT21); + h[3][0] = api.constant(INIT30); + h[3][1] = api.constant(INIT31); + h[4][0] = api.constant(INIT40); + h[4][1] = api.constant(INIT41); + h[5][0] = api.constant(INIT50); + h[5][1] = api.constant(INIT51); + h[6][0] = api.constant(INIT60); + h[6][1] = api.constant(INIT61); + h[7][0] = api.constant(INIT70); + h[7][1] = api.constant(INIT71); + let mut kbits_u8 = [[0; 32]; 64]; + for i in 0..64 { + for j in 0..32 { + kbits_u8[i][j] = ((_K[i] >> j) & 1) as u8; + } + } + let mut kbits = [[api.constant(0); 32]; 64]; + for i in 0..64 { + for j in 0..32 { + kbits[i][j] = api.constant(kbits_u8[i][j] as u32); + } + } + MyDigest { + h, + nx: 0, + len: 0, + kbits, + } + } + fn reset>(&mut self, api: &mut B) { + for i in 0..8 { + self.h[i] = [api.constant(0); 2]; + } + self.h[0][0] = api.constant(INIT00); + self.h[0][1] = api.constant(INIT01); + self.h[1][0] = api.constant(INIT10); + self.h[1][1] = api.constant(INIT11); + self.h[2][0] = api.constant(INIT20); + self.h[2][1] = api.constant(INIT21); + self.h[3][0] = api.constant(INIT30); + self.h[3][1] = api.constant(INIT31); + self.h[4][0] = api.constant(INIT40); + self.h[4][1] = api.constant(INIT41); + self.h[5][0] = api.constant(INIT50); + self.h[5][1] = api.constant(INIT51); + self.h[6][0] = api.constant(INIT60); + self.h[6][1] = api.constant(INIT61); + self.h[7][0] = api.constant(INIT70); + self.h[7][1] = api.constant(INIT71); + self.nx = 0; + self.len = 0; + } + //always write a chunk + fn chunk_write>(&mut self, api: &mut B, p: &[Variable]) { + if p.len() != CHUNK || self.nx != 0 { + panic!("p.len() != CHUNK || self.nx != 0"); + } + self.len += CHUNK as u64; + let tmp_h = self.h; + self.h = self.block(api, tmp_h, p); + } + fn return_sum>(&mut self, api: &mut B) -> [Variable; SHA256LEN] { + let mut digest = [api.constant(0); SHA256LEN]; + + big_endian_m31_array_put_uint32(api, &mut digest[0..], self.h[0]); + big_endian_m31_array_put_uint32(api, &mut digest[4..], self.h[1]); + big_endian_m31_array_put_uint32(api, &mut digest[8..], self.h[2]); + big_endian_m31_array_put_uint32(api, &mut digest[12..], self.h[3]); + big_endian_m31_array_put_uint32(api, &mut digest[16..], self.h[4]); + big_endian_m31_array_put_uint32(api, &mut digest[20..], self.h[5]); + big_endian_m31_array_put_uint32(api, &mut digest[24..], self.h[6]); + big_endian_m31_array_put_uint32(api, &mut digest[28..], self.h[7]); + digest + } + + fn block>( + &mut self, + api: &mut B, + h: [[Variable; 2]; 8], + p: &[Variable], + ) -> [[Variable; 2]; 8] { + let mut p = p; + let mut hh = h; + while p.len() >= CHUNK { + let mut msg_schedule = vec![]; + for t in 0..64 { + if t <= 15 { + msg_schedule.push(bytes_to_bits(api, &p[t * 4..t * 4 + 4])); + } else { + let term1_tmp = sigma1(api, &msg_schedule[t - 2]); + let term1 = bit_array_to_m31(api, &term1_tmp); + let term2 = bit_array_to_m31(api, &msg_schedule[t - 7]); + let term3_tmp = sigma0(api, &msg_schedule[t - 15]); + let term3 = bit_array_to_m31(api, &term3_tmp); + let term4 = bit_array_to_m31(api, &msg_schedule[t - 16]); + let schedule_tmp1 = big_array_add(api, &term1, &term2, 30); + let schedule_tmp2 = big_array_add(api, &term3, &term4, 30); + let schedule = big_array_add(api, &schedule_tmp1, &schedule_tmp2, 30); + let schedule_bits = m31_to_bit_array(api, &schedule)[..32].to_vec(); + msg_schedule.push(schedule_bits); + } + } + let mut a = hh[0].to_vec(); + let mut b = hh[1].to_vec(); + let mut c = hh[2].to_vec(); + let mut d = hh[3].to_vec(); + let mut e = hh[4].to_vec(); + let mut f = hh[5].to_vec(); + let mut g = hh[6].to_vec(); + let mut h = hh[7].to_vec(); + + //rewrite + let mut a_bit = m31_to_bit_array(api, &a)[..32].to_vec(); + let mut b_bit = m31_to_bit_array(api, &b)[..32].to_vec(); + let mut c_bit = m31_to_bit_array(api, &c)[..32].to_vec(); + let mut e_bit = m31_to_bit_array(api, &e)[..32].to_vec(); + let mut f_bit = m31_to_bit_array(api, &f)[..32].to_vec(); + let mut g_bit = m31_to_bit_array(api, &g)[..32].to_vec(); + for (t, schedule) in msg_schedule.iter().enumerate().take(64) { + let mut t1_term1 = [api.constant(0); 2]; + t1_term1[0] = h[0]; + t1_term1[1] = h[1]; + let t1_term2_tmp = cap_sigma1(api, &e_bit); + let t1_term2 = bit_array_to_m31(api, &t1_term2_tmp); + let t1_term3_tmp = ch(api, &e_bit, &f_bit, &g_bit); + let t1_term3 = bit_array_to_m31(api, &t1_term3_tmp); + let t1_term4 = bit_array_to_m31(api, &self.kbits[t]); //rewrite to [2]frontend.Variable + let t1_term5 = bit_array_to_m31(api, schedule); + let tmp1 = big_array_add(api, &t1_term1, &t1_term2, 30); + let tmp2 = big_array_add(api, &t1_term3, &t1_term4, 30); + let tmp3 = big_array_add(api, &tmp1, &tmp2, 30); + let tmp4 = big_array_add(api, &tmp3, &t1_term5, 30); + let t1 = tmp4; + let t2_tmp1 = cap_sigma0(api, &a_bit); + let t2_tmp2 = bit_array_to_m31(api, &t2_tmp1); + let t2_tmp3 = maj(api, &a_bit, &b_bit, &c_bit); + let t2_tmp4 = bit_array_to_m31(api, &t2_tmp3); + let t2 = big_array_add(api, &t2_tmp2, &t2_tmp4, 30); + let new_a_bit_tmp = big_array_add(api, &t1, &t2, 30); + let new_a_bit = m31_to_bit_array(api, &new_a_bit_tmp)[..32].to_vec(); + let new_e_bit_tmp = big_array_add(api, &d[..2], &t1, 30); + let new_e_bit = m31_to_bit_array(api, &new_e_bit_tmp)[..32].to_vec(); + h = g.to_vec(); + g = f.to_vec(); + f = e.to_vec(); + d = c.to_vec(); + c = b.to_vec(); + b = a.to_vec(); + a = bit_array_to_m31(api, &new_a_bit).to_vec(); + e = bit_array_to_m31(api, &new_e_bit).to_vec(); + g_bit = f_bit.to_vec(); + f_bit = e_bit.to_vec(); + c_bit = b_bit.to_vec(); + b_bit = a_bit.to_vec(); + a_bit = new_a_bit.to_vec(); + e_bit = new_e_bit.to_vec(); + } + let hh0_tmp1 = big_array_add(api, &hh[0], &a, 30); + let hh0_tmp2 = m31_to_bit_array(api, &hh0_tmp1); + hh[0] = bit_array_to_m31(api, &hh0_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh1_tmp1 = big_array_add(api, &hh[1], &b, 30); + let hh1_tmp2 = m31_to_bit_array(api, &hh1_tmp1); + hh[1] = bit_array_to_m31(api, &hh1_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh2_tmp1 = big_array_add(api, &hh[2], &c, 30); + let hh2_tmp2 = m31_to_bit_array(api, &hh2_tmp1); + hh[2] = bit_array_to_m31(api, &hh2_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh3_tmp1 = big_array_add(api, &hh[3], &d, 30); + let hh3_tmp2 = m31_to_bit_array(api, &hh3_tmp1); + hh[3] = bit_array_to_m31(api, &hh3_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh4_tmp1 = big_array_add(api, &hh[4], &e, 30); + let hh4_tmp2 = m31_to_bit_array(api, &hh4_tmp1); + hh[4] = bit_array_to_m31(api, &hh4_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh5_tmp1 = big_array_add(api, &hh[5], &f, 30); + let hh5_tmp2 = m31_to_bit_array(api, &hh5_tmp1); + hh[5] = bit_array_to_m31(api, &hh5_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh6_tmp1 = big_array_add(api, &hh[6], &g, 30); + let hh6_tmp2 = m31_to_bit_array(api, &hh6_tmp1); + hh[6] = bit_array_to_m31(api, &hh6_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh7_tmp1 = big_array_add(api, &hh[7], &h, 30); + let hh7_tmp2 = m31_to_bit_array(api, &hh7_tmp1); + hh[7] = bit_array_to_m31(api, &hh7_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + p = &p[CHUNK..]; + } + hh + } +} + +pub fn sha256_37bytes>( + builder: &mut B, + orign_data: &[Variable], +) -> Vec { + let mut data = orign_data.to_vec(); + let n = data.len(); + if n != 32 + 1 + 4 { + panic!("len(orignData) != 32+1+4") + } + let mut pre_pad = vec![builder.constant(0); 64 - 37]; + pre_pad[0] = builder.constant(128); //0x80 + pre_pad[64 - 37 - 2] = builder.constant((37) * 8 / 256); //length byte + pre_pad[64 - 37 - 1] = builder.constant((32 + 1 + 4) * 8 - 256); //length byte + data.append(&mut pre_pad); //append padding + let mut d = MyDigest::new(builder); + d.reset(builder); + d.chunk_write(builder, &data); + d.return_sum(builder).to_vec() +} diff --git a/circuit-std-rs/src/sha256/m31_utils.rs b/circuit-std-rs/src/sha256/m31_utils.rs new file mode 100644 index 00000000..32942a80 --- /dev/null +++ b/circuit-std-rs/src/sha256/m31_utils.rs @@ -0,0 +1,409 @@ +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; + +pub fn bytes_to_bits>(api: &mut B, vals: &[Variable]) -> Vec { + let mut ret = to_binary(api, vals[0], 8); + for val in vals.iter().skip(1) { + ret = to_binary(api, *val, 8) + .into_iter() + .chain(ret.into_iter()) + .collect(); + } + ret +} +pub fn right_shift>( + api: &mut B, + bits: &[Variable], + shift: usize, +) -> Vec { + if bits.len() != 32 { + panic!("RightShift: len(bits) != 32"); + } + let mut shifted_bits = bits[shift..].to_vec(); + for _ in 0..shift { + shifted_bits.push(api.constant(0)); + } + shifted_bits +} +pub fn rotate_right(bits: &[Variable], shift: usize) -> Vec { + if bits.len() != 32 { + panic!("RotateRight: len(bits) != 32"); + } + let mut rotated_bits = bits[shift..].to_vec(); + rotated_bits.extend_from_slice(&bits[..shift]); + rotated_bits +} +pub fn sigma0>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("Sigma0: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 7); + let v2 = rotate_right(&bits2, 18); + let v3 = right_shift(api, &bits3, 3); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn sigma1>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("Sigma1: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 17); + let v2 = rotate_right(&bits2, 19); + let v3 = right_shift(api, &bits3, 10); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn cap_sigma0>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("CapSigma0: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 2); + let v2 = rotate_right(&bits2, 13); + let v3 = rotate_right(&bits3, 22); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn cap_sigma1>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("CapSigma1: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 6); + let v2 = rotate_right(&bits2, 11); + let v3 = rotate_right(&bits3, 25); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn ch>( + api: &mut B, + x: &[Variable], + y: &[Variable], + z: &[Variable], +) -> Vec { + if x.len() != 32 || y.len() != 32 || z.len() != 32 { + panic!("Ch: len(x) != 32 || len(y) != 32 || len(z) != 32"); + } + let mut ret = vec![]; + for i in 0..32 { + let tmp1 = api.and(x[i], y[i]); + let tmp2 = api.xor(x[i], 1); + let tmp3 = api.and(tmp2, z[i]); + ret.push(api.xor(tmp1, tmp3)); + } + ret +} +pub fn maj>( + api: &mut B, + x: &[Variable], + y: &[Variable], + z: &[Variable], +) -> Vec { + if x.len() != 32 || y.len() != 32 || z.len() != 32 { + panic!("Maj: len(x) != 32 || len(y) != 32 || len(z) != 32"); + } + let mut ret = vec![]; + for i in 0..32 { + let tmp1 = api.and(x[i], y[i]); + let tmp2 = api.and(x[i], z[i]); + let tmp3 = api.and(y[i], z[i]); + let tmp4 = api.xor(tmp1, tmp2); + ret.push(api.xor(tmp3, tmp4)); + } + ret +} +pub fn big_array_add>( + api: &mut B, + a: &[Variable], + b: &[Variable], + nb_bits: usize, +) -> Vec { + if a.len() != b.len() { + panic!("BigArrayAdd: length of a and b must be equal"); + } + let mut c = vec![api.constant(0); a.len()]; + let mut carry = api.constant(0); + for i in 0..a.len() { + c[i] = api.add(a[i], b[i]); + c[i] = api.add(c[i], carry); + carry = to_binary(api, c[i], nb_bits + 1)[nb_bits]; + let tmp = api.mul(carry, 1 << nb_bits); + c[i] = api.sub(c[i], tmp); + } + c +} +pub fn bit_array_to_m31>(api: &mut B, bits: &[Variable]) -> [Variable; 2] { + if bits.len() >= 60 { + panic!("BitArrayToM31: length of bits must be less than 60"); + } + [ + from_binary(api, bits[..30].to_vec()), + from_binary(api, bits[30..].to_vec()), + ] +} + +pub fn big_endian_m31_array_put_uint32>( + api: &mut B, + b: &mut [Variable], + x: [Variable; 2], +) { + let mut quo = x[0]; + for i in (1..=3).rev() { + let (q, r) = idiv_mod_bit(api, quo, 8); + b[i] = r; + quo = q; + } + let shift = api.mul(x[1], 1 << 6); + b[0] = api.add(quo, shift); +} + +pub fn big_endian_put_uint64>( + api: &mut B, + b: &mut [Variable], + x: Variable, +) { + let mut quo = x; + for i in (1..=7).rev() { + let (q, r) = idiv_mod_bit(api, quo, 8); + b[i] = r; + quo = q; + } + b[0] = quo; +} +pub fn m31_to_bit_array>(api: &mut B, m31: &[Variable]) -> Vec { + let mut bits = vec![]; + for val in m31 { + bits.extend_from_slice(&to_binary(api, *val, 30)); + } + bits +} +pub fn to_binary>( + api: &mut B, + x: Variable, + n_bits: usize, +) -> Vec { + api.new_hint("myhint.tobinary", &[x], n_bits) +} +pub fn from_binary>(api: &mut B, bits: Vec) -> Variable { + let mut res = api.constant(0); + for (i, bit) in bits.iter().enumerate() { + let coef = 1 << i; + let cur = api.mul(coef, *bit); + res = api.add(res, cur); + } + res +} + +pub fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +pub fn big_is_zero>(api: &mut B, k: usize, in_: &[Variable]) -> Variable { + let mut total = api.constant(k as u32); + for val in in_.iter().take(k) { + let tmp = api.is_zero(val); + total = api.sub(total, tmp); + } + api.is_zero(total) +} + +pub fn bigint_to_m31_array>( + api: &mut B, + x: BigInt, + n_bits: usize, + limb_len: usize, +) -> Vec { + let mut res = vec![]; + let mut a = x.clone(); + let mut mask = BigInt::from(1) << n_bits; + mask -= 1; + for _ in 0..limb_len { + let tmp = a.clone() & mask.clone(); + let tmp = api.constant(tmp.to_u32().unwrap()); + res.push(tmp); + a >>= n_bits; + } + res +} +pub fn big_less_than>( + api: &mut B, + n: usize, + k: usize, + a: &[Variable], + b: &[Variable], +) -> Variable { + let mut lt = vec![]; + let mut eq = vec![]; + for i in 0..k { + lt.push(my_is_less(api, n, a[i], b[i])); + let diff = api.sub(a[i], b[i]); + eq.push(api.is_zero(diff)); + } + let mut ors = vec![Variable::default(); k - 1]; + let mut ands = vec![Variable::default(); k - 1]; + let mut eq_ands = vec![Variable::default(); k - 1]; + for i in (0..k - 1).rev() { + if i == k - 2 { + ands[i] = api.and(eq[k - 1], lt[k - 2]); + eq_ands[i] = api.and(eq[k - 1], eq[k - 2]); + ors[i] = api.or(lt[k - 1], ands[k - 2]); + } else { + ands[i] = api.and(eq_ands[i + 1], lt[i]); + eq_ands[i] = api.and(eq_ands[i + 1], eq[i]); + ors[i] = api.or(ors[i + 1], ands[i]); + } + } + ors[0] +} +pub fn my_is_less>( + api: &mut B, + n: usize, + a: Variable, + b: Variable, +) -> Variable { + let neg_b = api.neg(b); + let tmp = api.add(a, 1 << n); + let tmp = api.add(tmp, neg_b); + let bi1 = to_binary(api, tmp, n + 1); + let one = api.constant(1); + api.sub(one, bi1[n]) +} + +pub fn idiv_mod_bit>( + builder: &mut B, + a: Variable, + b: u64, +) -> (Variable, Variable) { + let bits = to_binary(builder, a, 30); + let quotient = from_binary(builder, bits[b as usize..].to_vec()); + let remainder = from_binary(builder, bits[..b as usize].to_vec()); + (quotient, remainder) +} + +pub fn string_to_m31_array(s: &str, nb_bits: u32) -> [M31; 48] { + let mut big = + BigInt::parse_bytes(s.as_bytes(), 10).unwrap_or_else(|| panic!("Failed to parse BigInt")); + let mut res = [M31::from(0); 48]; + let base = BigInt::from(1) << nb_bits; + for cur_res in &mut res { + let tmp = &big % &base; + *cur_res = M31::from(tmp.to_u32().unwrap()); + big >>= nb_bits; + } + res +} + +declare_circuit!(IDIVMODBITCircuit { + value: PublicVariable, + quotient: Variable, + remainder: Variable, +}); + +impl Define for IDIVMODBITCircuit { + fn define(&self, builder: &mut API) { + let (quotient, remainder) = idiv_mod_bit(builder, self.value, 8); + builder.assert_is_equal(quotient, self.quotient); + builder.assert_is_equal(remainder, self.remainder); + } +} +#[test] +fn test_idiv_mod_bit() { + //register hints + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + //compile and test + let compile_result = compile(&IDIVMODBITCircuit::default()).unwrap(); + let assignment = IDIVMODBITCircuit:: { + value: M31::from(3845), + quotient: M31::from(15), + remainder: M31::from(5), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(BITCONVERTCircuit { + big_int: PublicVariable, + big_int_bytes: [Variable; 8], + big_int_m31: [Variable; 2], + big_int_m31_bytes: [Variable; 4], +}); + +impl Define for BITCONVERTCircuit { + fn define(&self, builder: &mut API) { + let mut big_int_bytes = [builder.constant(0); 8]; + big_endian_put_uint64(builder, &mut big_int_bytes, self.big_int); + for (i, big_int_byte) in big_int_bytes.iter().enumerate() { + builder.assert_is_equal(big_int_byte, self.big_int_bytes[i]); + } + let mut big_int_m31 = [builder.constant(0); 4]; + big_endian_m31_array_put_uint32(builder, &mut big_int_m31, self.big_int_m31); + for (i, val) in big_int_m31.iter().enumerate() { + builder.assert_is_equal(val, self.big_int_m31_bytes[i]); + } + } +} +#[test] +fn test_bit_convert() { + //register hints + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + //compile and test + let compile_result = compile(&BITCONVERTCircuit::default()).unwrap(); + let assignment = BITCONVERTCircuit:: { + big_int: M31::from(3845), + big_int_bytes: [ + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(15), + M31::from(5), + ], + big_int_m31: [M31::from(3845), M31::from(0)], + big_int_m31_bytes: [M31::from(0), M31::from(0), M31::from(15), M31::from(5)], + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/src/traits.rs b/circuit-std-rs/src/traits.rs index f42ca176..8d4fb4b1 100644 --- a/circuit-std-rs/src/traits.rs +++ b/circuit-std-rs/src/traits.rs @@ -8,7 +8,9 @@ pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables; + // Create a new circuit with the given parameters fn new_circuit(params: &Self::Params) -> Self; + // Create a new random assignment for the circuit fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; } diff --git a/circuit-std-rs/src/utils.rs b/circuit-std-rs/src/utils.rs new file mode 100644 index 00000000..898fe24e --- /dev/null +++ b/circuit-std-rs/src/utils.rs @@ -0,0 +1,30 @@ +use expander_compiler::frontend::*; + +pub fn simple_select>( + native: &mut B, + selector: Variable, + a: Variable, + b: Variable, +) -> Variable { + let tmp = native.sub(a, b); + let tmp2 = native.mul(tmp, selector); + native.add(b, tmp2) +} + +//return i0 if selector0 and selector 1 are 0 +//return i1 if selector0 is 1 and selector1 is 0 +//return i2 if selector0 is 0 and selector1 is 1 +//return d if selector0 and selector1 are 1 +pub fn simple_lookup2>( + native: &mut B, + selector0: Variable, + selector1: Variable, + i0: Variable, + i1: Variable, + i2: Variable, + i3: Variable, +) -> Variable { + let tmp0 = simple_select(native, selector0, i1, i0); + let tmp1 = simple_select(native, selector0, i3, i2); + simple_select(native, selector1, tmp1, tmp0) +} diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs index 1adb95a8..bf777187 100644 --- a/circuit-std-rs/tests/common.rs +++ b/circuit-std-rs/tests/common.rs @@ -9,8 +9,8 @@ where Cir: StdCircuit, { let mut rng = thread_rng(); - let compile_result: CompileResult = compile(&Cir::new_circuit(¶ms)).unwrap(); - let assignment = Cir::new_assignment(¶ms, &mut rng); + let compile_result: CompileResult = compile(&Cir::new_circuit(params)).unwrap(); + let assignment = Cir::new_assignment(params, &mut rng); let witness = compile_result .witness_solver .solve_witness(&assignment) diff --git a/circuit-std-rs/tests/gnark.rs b/circuit-std-rs/tests/gnark.rs new file mode 100644 index 00000000..fe9c1e24 --- /dev/null +++ b/circuit-std-rs/tests/gnark.rs @@ -0,0 +1,14 @@ +mod gnark { + mod emulated { + mod field_bls12381 { + mod e12; + mod e2; + mod e6; + } + mod sw_bls12381 { + mod g1; + mod pairing; + } + } + mod element; +} diff --git a/circuit-std-rs/tests/gnark/element.rs b/circuit-std-rs/tests/gnark/element.rs new file mode 100644 index 00000000..f5fce973 --- /dev/null +++ b/circuit-std-rs/tests/gnark/element.rs @@ -0,0 +1,95 @@ +#[cfg(test)] +mod tests { + use circuit_std_rs::gnark::{ + element::{from_interface, value_of}, + emparam::Bls12381Fp, + }; + use expander_compiler::frontend::*; + use num_bigint::BigInt; + #[test] + fn test_from_interface() { + let v = 1111111u32; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(1111111u32)); + let v = 22222222222222u64; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(22222222222222u64)); + let v = 333333usize; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(333333usize as u64)); + let v = 444444i32; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(444444i32)); + let v = 555555555555555i64; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(555555555555555i64)); + let v = 666isize; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(666isize as i64)); + let v = "77777777777777777".to_string(); + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(77777777777777777u64)); + let v = vec![7u8; 4]; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(0x07070707u32)); + } + + declare_circuit!(VALUECircuit { + target: [[Variable; 48]; 8], + }); + impl Define for VALUECircuit { + fn define(&self, builder: &mut API) { + let v1 = 1111111u32; + let v2 = 22222222222222u64; + let v3 = 333333usize; + let v4 = 444444i32; + let v5 = 555555555555555i64; + let v6 = 666isize; + let v7 = "77777777777777777".to_string(); + let v8 = vec![8u8; 4]; + + let r1 = value_of::(builder, Box::new(v1)); + let r2 = value_of::(builder, Box::new(v2)); + let r3 = value_of::(builder, Box::new(v3)); + let r4 = value_of::(builder, Box::new(v4)); + let r5 = value_of::(builder, Box::new(v5)); + let r6 = value_of::(builder, Box::new(v6)); + let r7 = value_of::(builder, Box::new(v7)); + let r8 = value_of::(builder, Box::new(v8)); + let rs = vec![r1, r2, r3, r4, r5, r6, r7, r8]; + for i in 0..rs.len() { + for j in 0..rs[i].limbs.len() { + builder.assert_is_equal(rs[i].limbs[j], self.target[i][j]); + } + } + } + } + + #[test] + fn test_value() { + let values: Vec = vec![ + 1111111, + 22222222222222, + 333333, + 444444, + 555555555555555, + 666, + 77777777777777777, + 0x08080808, + ]; + let values_u8: Vec> = values.iter().map(|v| v.to_le_bytes().to_vec()).collect(); + let compile_result = compile(&VALUECircuit::default()).unwrap(); + let mut assignment = VALUECircuit::::default(); + for i in 0..values_u8.len() { + for j in 0..values_u8[i].len() { + assignment.target[i][j] = M31::from(values_u8[i][j] as u32); + } + } + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs new file mode 100644 index 00000000..fb9ca916 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs @@ -0,0 +1,2397 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e12::{Ext12, GE12}, + e2::GE2, + e6::GE6, + }, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(E12AddCircuit { + x: [[[[Variable; 48]; 2]; 3]; 2], + y: [[[[Variable; 48]; 2]; 3]; 2], + z: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + let x_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[0][1][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[0][2][0].to_vec(), 0), + a1: new_internal_element(self.x[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.x[1][0][0].to_vec(), 0), + a1: new_internal_element(self.x[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[1][2][0].to_vec(), 0), + a1: new_internal_element(self.x[1][2][1].to_vec(), 0), + }, + }, + }; + let y_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[0][1][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[0][2][0].to_vec(), 0), + a1: new_internal_element(self.y[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.y[1][0][0].to_vec(), 0), + a1: new_internal_element(self.y[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[1][2][0].to_vec(), 0), + a1: new_internal_element(self.y[1][2][1].to_vec(), 0), + }, + }, + }; + let z_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[0][1][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[0][2][0].to_vec(), 0), + a1: new_internal_element(self.z[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.z[1][0][0].to_vec(), 0), + a1: new_internal_element(self.z[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[1][2][0].to_vec(), 0), + a1: new_internal_element(self.z[1][2][1].to_vec(), 0), + }, + }, + }; + let z = ext12.add(builder, &x_e12, &y_e12); + ext12.assert_isequal(builder, &z, &z_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} +#[test] +fn test_e12_add() { + compile_generic(&E12AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E12AddCircuit:: { + x: [[[[M31::from(0); 48]; 2]; 3]; 2], + y: [[[[M31::from(0); 48]; 2]; 3]; 2], + z: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 230, 7, 244, 92, 237, 70, 117, 94, 82, 55, 74, 196, 172, 118, 86, 33, 195, 231, 218, 215, + 169, 200, 47, 95, 2, 162, 203, 215, 88, 27, 146, 255, 185, 205, 74, 164, 252, 251, 241, 36, + 112, 228, 157, 87, 122, 78, 189, 18, + ]; + let x0_c0_b0_a1_bytes = [ + 123, 74, 33, 121, 6, 155, 7, 109, 108, 65, 144, 138, 43, 39, 102, 201, 193, 139, 222, 60, + 96, 210, 211, 212, 214, 250, 64, 56, 217, 19, 222, 230, 161, 139, 175, 92, 207, 204, 60, + 236, 42, 23, 130, 36, 116, 94, 235, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 49, 127, 28, 75, 52, 125, 232, 138, 94, 244, 108, 5, 97, 129, 205, 223, 92, 250, 249, 164, + 70, 188, 87, 59, 88, 120, 208, 94, 48, 41, 13, 251, 243, 5, 118, 105, 177, 148, 29, 54, + 156, 135, 64, 151, 157, 0, 119, 7, + ]; + let x0_c0_b1_a1_bytes = [ + 111, 133, 18, 247, 78, 21, 80, 154, 216, 230, 186, 223, 109, 228, 163, 119, 98, 30, 52, + 145, 174, 146, 135, 230, 44, 58, 58, 70, 56, 108, 96, 150, 67, 181, 53, 124, 38, 92, 190, + 174, 68, 18, 176, 112, 232, 23, 102, 7, + ]; + let x0_c0_b2_a0_bytes = [ + 194, 50, 236, 56, 30, 253, 216, 230, 252, 43, 62, 251, 37, 124, 173, 107, 236, 62, 190, + 121, 225, 13, 255, 152, 137, 221, 37, 23, 178, 16, 232, 244, 15, 29, 1, 229, 201, 43, 27, + 85, 173, 191, 250, 2, 43, 39, 206, 12, + ]; + let x0_c0_b2_a1_bytes = [ + 141, 208, 78, 212, 20, 209, 73, 151, 224, 146, 235, 177, 88, 38, 231, 36, 205, 8, 223, 66, + 35, 157, 28, 37, 123, 92, 239, 77, 190, 243, 142, 2, 228, 145, 241, 47, 251, 55, 59, 116, + 195, 196, 90, 86, 171, 39, 236, 12, + ]; + let x0_c1_b0_a0_bytes = [ + 169, 135, 2, 13, 240, 185, 47, 225, 235, 154, 118, 30, 95, 163, 223, 25, 184, 76, 152, 231, + 206, 120, 67, 227, 223, 228, 226, 172, 134, 24, 174, 108, 8, 21, 235, 122, 63, 78, 129, + 226, 8, 205, 153, 206, 152, 214, 164, 12, + ]; + let x0_c1_b0_a1_bytes = [ + 250, 192, 145, 229, 203, 199, 112, 129, 255, 241, 90, 53, 11, 91, 241, 117, 135, 247, 116, + 237, 193, 5, 104, 198, 55, 136, 215, 148, 136, 67, 185, 172, 209, 102, 122, 64, 180, 67, + 152, 220, 92, 166, 177, 36, 137, 82, 210, 4, + ]; + let x0_c1_b1_a0_bytes = [ + 86, 8, 54, 207, 80, 124, 211, 250, 195, 16, 41, 225, 151, 234, 74, 235, 6, 80, 128, 23, + 208, 150, 90, 168, 123, 66, 153, 230, 12, 192, 202, 28, 163, 221, 28, 76, 58, 73, 101, 1, + 243, 250, 133, 26, 228, 172, 88, 12, + ]; + let x0_c1_b1_a1_bytes = [ + 100, 82, 131, 139, 164, 216, 135, 48, 179, 232, 54, 9, 39, 131, 147, 137, 241, 60, 21, 218, + 161, 102, 144, 134, 81, 101, 64, 0, 5, 131, 214, 170, 224, 123, 11, 25, 160, 89, 220, 166, + 193, 45, 13, 100, 230, 116, 112, 24, + ]; + let x0_c1_b2_a0_bytes = [ + 247, 221, 42, 90, 51, 107, 26, 120, 49, 75, 158, 9, 75, 55, 71, 121, 59, 126, 96, 1, 14, + 248, 253, 151, 143, 29, 83, 249, 204, 94, 105, 120, 21, 8, 170, 27, 117, 166, 25, 117, 119, + 196, 147, 115, 60, 10, 53, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 46, 173, 19, 115, 230, 103, 157, 253, 229, 42, 46, 181, 62, 74, 133, 99, 144, 63, 196, 246, + 4, 132, 203, 228, 77, 114, 70, 247, 63, 15, 138, 100, 9, 32, 145, 80, 245, 98, 110, 218, + 156, 33, 57, 62, 43, 98, 81, 18, + ]; + let x1_c0_b0_a0_bytes = [ + 148, 30, 71, 204, 89, 128, 39, 211, 200, 173, 12, 53, 49, 151, 93, 248, 122, 184, 53, 28, + 126, 17, 19, 194, 199, 192, 84, 54, 197, 99, 7, 123, 243, 77, 94, 235, 77, 57, 176, 95, + 211, 166, 170, 169, 219, 136, 143, 16, + ]; + let x1_c0_b0_a1_bytes = [ + 116, 165, 190, 228, 91, 60, 196, 159, 85, 252, 213, 69, 1, 2, 255, 229, 48, 82, 242, 236, + 138, 116, 18, 142, 211, 226, 1, 27, 172, 39, 110, 176, 116, 224, 29, 170, 150, 162, 188, + 133, 134, 187, 63, 39, 42, 233, 223, 21, + ]; + let x1_c0_b1_a0_bytes = [ + 52, 188, 3, 110, 86, 230, 166, 129, 55, 12, 222, 175, 157, 177, 232, 228, 128, 150, 69, 11, + 254, 146, 229, 48, 88, 212, 25, 142, 49, 186, 136, 155, 251, 188, 234, 79, 116, 72, 200, + 26, 16, 2, 44, 141, 51, 243, 107, 25, + ]; + let x1_c0_b1_a1_bytes = [ + 189, 11, 14, 178, 64, 171, 213, 99, 42, 92, 224, 19, 135, 91, 69, 10, 17, 74, 95, 100, 229, + 165, 14, 89, 76, 7, 26, 12, 141, 254, 74, 178, 222, 63, 209, 235, 231, 191, 198, 239, 111, + 184, 20, 119, 247, 206, 137, 21, + ]; + let x1_c0_b2_a0_bytes = [ + 212, 172, 221, 198, 21, 214, 123, 10, 204, 162, 176, 184, 103, 196, 108, 104, 238, 168, + 120, 68, 50, 179, 148, 56, 3, 150, 2, 153, 240, 153, 144, 156, 154, 0, 122, 112, 38, 167, + 188, 90, 58, 54, 253, 203, 30, 18, 116, 22, + ]; + let x1_c0_b2_a1_bytes = [ + 90, 124, 114, 30, 19, 47, 172, 69, 32, 76, 109, 59, 202, 137, 251, 14, 81, 116, 190, 33, + 48, 205, 103, 135, 26, 77, 174, 125, 197, 102, 92, 138, 15, 20, 230, 7, 205, 140, 129, 234, + 229, 245, 234, 158, 122, 90, 136, 20, + ]; + let x1_c1_b0_a0_bytes = [ + 200, 82, 45, 114, 38, 64, 114, 217, 14, 159, 26, 201, 98, 79, 228, 4, 175, 96, 242, 120, + 46, 134, 147, 59, 150, 169, 115, 61, 246, 17, 80, 231, 88, 50, 192, 43, 236, 13, 195, 51, + 88, 2, 150, 109, 127, 175, 212, 11, + ]; + let x1_c1_b0_a1_bytes = [ + 90, 205, 64, 128, 120, 157, 119, 255, 181, 86, 183, 85, 39, 214, 168, 122, 184, 70, 236, + 137, 17, 168, 133, 48, 19, 22, 156, 44, 154, 42, 65, 94, 10, 74, 77, 91, 168, 172, 235, + 220, 114, 60, 8, 25, 65, 146, 138, 10, + ]; + let x1_c1_b1_a0_bytes = [ + 79, 42, 100, 15, 28, 174, 145, 214, 133, 51, 126, 38, 14, 120, 235, 155, 26, 216, 119, 134, + 149, 230, 93, 241, 130, 50, 39, 124, 254, 144, 244, 88, 224, 222, 252, 49, 70, 167, 245, + 170, 157, 178, 32, 1, 188, 90, 249, 25, + ]; + let x1_c1_b1_a1_bytes = [ + 23, 37, 23, 6, 168, 183, 104, 99, 161, 213, 146, 108, 40, 203, 206, 138, 143, 9, 137, 68, + 6, 6, 215, 212, 160, 97, 220, 1, 20, 120, 149, 233, 158, 220, 164, 74, 228, 63, 10, 243, + 109, 171, 93, 139, 56, 187, 111, 9, + ]; + let x1_c1_b2_a0_bytes = [ + 88, 170, 4, 14, 40, 128, 9, 37, 112, 153, 51, 44, 207, 24, 160, 166, 202, 141, 45, 176, + 216, 247, 252, 83, 79, 125, 219, 52, 45, 47, 195, 0, 109, 64, 17, 233, 109, 171, 86, 64, + 101, 17, 110, 125, 8, 209, 220, 14, + ]; + let x1_c1_b2_a1_bytes = [ + 45, 80, 195, 74, 220, 212, 197, 127, 138, 75, 183, 100, 244, 133, 63, 126, 203, 191, 237, + 238, 226, 187, 191, 134, 30, 11, 201, 89, 71, 197, 47, 97, 183, 210, 75, 121, 252, 204, 21, + 52, 14, 136, 175, 8, 7, 47, 128, 23, + ]; + let x2_c0_b0_a0_bytes = [ + 207, 123, 59, 41, 71, 199, 157, 119, 27, 229, 2, 72, 223, 13, 8, 251, 25, 170, 95, 253, + 134, 7, 18, 186, 10, 80, 155, 26, 153, 51, 34, 22, 214, 110, 93, 76, 148, 141, 134, 57, + 169, 164, 200, 199, 107, 197, 75, 9, + ]; + let x2_c0_b0_a1_bytes = [ + 68, 69, 224, 93, 98, 215, 204, 82, 194, 61, 18, 31, 46, 41, 185, 144, 206, 231, 31, 51, 74, + 116, 181, 251, 234, 202, 189, 95, 0, 240, 212, 50, 63, 191, 129, 195, 175, 199, 221, 38, + 23, 236, 65, 18, 180, 53, 202, 18, + ]; + let x2_c0_b1_a0_bytes = [ + 186, 144, 32, 185, 138, 99, 144, 82, 150, 0, 247, 3, 0, 51, 10, 166, 185, 154, 142, 185, + 163, 124, 12, 5, 241, 57, 101, 249, 220, 151, 30, 50, 24, 22, 21, 118, 111, 53, 202, 5, 18, + 163, 236, 234, 230, 225, 225, 6, + ]; + let x2_c0_b1_a1_bytes = [ + 129, 230, 32, 169, 143, 192, 38, 68, 3, 67, 71, 66, 246, 63, 61, 99, 79, 114, 226, 254, + 242, 101, 101, 216, 185, 46, 207, 94, 64, 31, 52, 228, 74, 72, 187, 36, 88, 116, 105, 83, + 26, 228, 68, 174, 245, 212, 238, 2, + ]; + let x2_c0_b2_a0_bytes = [ + 235, 52, 202, 255, 51, 211, 85, 55, 201, 206, 154, 2, 143, 64, 110, 181, 182, 241, 133, + 199, 114, 238, 98, 106, 205, 96, 163, 188, 29, 95, 1, 45, 211, 112, 47, 18, 58, 43, 188, + 100, 77, 15, 120, 149, 95, 39, 65, 9, + ]; + let x2_c0_b2_a1_bytes = [ + 60, 162, 193, 242, 39, 0, 247, 34, 1, 223, 4, 60, 36, 176, 54, 21, 250, 134, 236, 109, 178, + 151, 83, 69, 214, 150, 24, 216, 254, 14, 116, 40, 28, 249, 139, 244, 17, 29, 161, 19, 15, + 212, 197, 187, 59, 112, 115, 7, + ]; + let x2_c1_b0_a0_bytes = [ + 113, 218, 47, 127, 22, 250, 161, 186, 250, 57, 145, 231, 193, 242, 195, 30, 103, 173, 138, + 96, 253, 254, 214, 30, 118, 142, 86, 234, 124, 42, 254, 83, 97, 71, 171, 166, 43, 92, 68, + 22, 97, 207, 47, 60, 24, 134, 121, 24, + ]; + let x2_c1_b0_a1_bytes = [ + 84, 142, 210, 101, 68, 101, 232, 128, 181, 72, 18, 139, 50, 49, 154, 240, 63, 62, 97, 119, + 211, 173, 237, 246, 74, 158, 115, 193, 34, 110, 250, 10, 220, 176, 199, 155, 92, 240, 131, + 185, 207, 226, 185, 61, 202, 228, 92, 15, + ]; + let x2_c1_b1_a0_bytes = [ + 250, 135, 154, 222, 108, 42, 102, 23, 74, 68, 83, 86, 167, 98, 138, 104, 253, 49, 71, 167, + 196, 170, 135, 50, 63, 98, 59, 111, 134, 5, 72, 17, 172, 15, 206, 58, 202, 72, 63, 97, 246, + 198, 38, 226, 181, 245, 80, 12, + ]; + let x2_c1_b1_a1_bytes = [ + 208, 204, 154, 145, 76, 144, 241, 217, 84, 190, 117, 196, 80, 78, 182, 245, 92, 80, 237, + 39, 7, 154, 54, 244, 50, 180, 151, 14, 148, 175, 244, 47, 168, 171, 100, 32, 206, 241, 202, + 78, 149, 242, 234, 181, 52, 30, 223, 7, + ]; + let x2_c1_b2_a0_bytes = [ + 164, 221, 47, 104, 91, 235, 36, 227, 161, 228, 125, 132, 27, 80, 59, 1, 226, 21, 221, 186, + 69, 29, 202, 132, 31, 136, 169, 58, 117, 66, 181, 20, 171, 155, 111, 193, 44, 170, 84, 106, + 66, 239, 129, 183, 90, 201, 16, 2, + ]; + let x2_c1_b2_a1_bytes = [ + 176, 82, 215, 189, 194, 60, 100, 195, 112, 118, 145, 104, 52, 208, 24, 195, 55, 9, 1, 239, + 70, 109, 90, 4, 173, 106, 138, 93, 2, 137, 66, 97, 233, 69, 145, 134, 59, 136, 104, 195, + 16, 195, 104, 13, 72, 127, 208, 15, + ]; + + for i in 0..48 { + assignment.x[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.x[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.x[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.x[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.x[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.x[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.x[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.x[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.x[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.x[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.x[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.x[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.y[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.y[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.y[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.y[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.y[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.y[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.y[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.y[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.y[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.y[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.y[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.y[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.z[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.z[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.z[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.z[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.z[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.z[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.z[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.z[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.z[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.z[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.z[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.z[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12AddCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E12SubCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.sub(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_sub() { + compile_generic(&E12SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12SubCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 197, 236, 193, 85, 161, 111, 30, 106, 84, 151, 195, 17, 249, 224, 84, 244, 234, 151, 155, + 63, 74, 153, 175, 165, 235, 125, 153, 130, 243, 107, 14, 105, 245, 28, 233, 106, 75, 57, + 94, 106, 84, 180, 23, 57, 67, 110, 184, 10, + ]; + let x0_c0_b0_a1_bytes = [ + 59, 96, 28, 50, 133, 228, 182, 73, 66, 218, 225, 164, 193, 187, 245, 231, 228, 192, 66, 73, + 171, 154, 154, 62, 133, 130, 233, 245, 172, 151, 229, 221, 180, 146, 34, 210, 144, 85, 244, + 82, 184, 183, 27, 180, 223, 136, 102, 24, + ]; + let x0_c0_b1_a0_bytes = [ + 84, 55, 219, 118, 151, 133, 30, 81, 23, 129, 216, 253, 231, 146, 81, 239, 82, 143, 143, + 240, 153, 190, 91, 53, 196, 35, 118, 126, 126, 117, 228, 158, 50, 171, 35, 147, 148, 104, + 198, 50, 111, 65, 153, 100, 245, 126, 124, 7, + ]; + let x0_c0_b1_a1_bytes = [ + 158, 71, 191, 118, 128, 142, 50, 104, 161, 113, 119, 153, 140, 128, 153, 6, 169, 32, 115, + 6, 250, 209, 208, 97, 194, 1, 162, 91, 12, 42, 22, 245, 136, 71, 91, 95, 227, 52, 40, 208, + 108, 112, 216, 18, 58, 137, 192, 1, + ]; + let x0_c0_b2_a0_bytes = [ + 228, 37, 132, 99, 194, 152, 42, 52, 22, 111, 105, 49, 77, 137, 143, 217, 244, 72, 169, 243, + 233, 48, 144, 134, 104, 208, 140, 34, 253, 229, 139, 181, 9, 39, 20, 5, 49, 42, 213, 22, + 78, 66, 164, 172, 111, 223, 186, 22, + ]; + let x0_c0_b2_a1_bytes = [ + 91, 255, 84, 235, 130, 162, 183, 217, 231, 118, 130, 247, 180, 1, 189, 144, 216, 166, 141, + 55, 72, 168, 144, 255, 240, 224, 253, 181, 195, 202, 154, 136, 143, 131, 24, 12, 18, 54, + 102, 200, 132, 179, 33, 73, 73, 129, 120, 24, + ]; + let x0_c1_b0_a0_bytes = [ + 75, 145, 107, 24, 225, 40, 95, 38, 248, 143, 36, 81, 242, 205, 106, 97, 93, 79, 202, 24, + 215, 215, 203, 153, 98, 58, 232, 124, 142, 40, 126, 86, 171, 9, 120, 56, 12, 102, 208, 245, + 103, 47, 55, 136, 96, 157, 196, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 195, 95, 9, 22, 123, 87, 85, 52, 125, 17, 135, 205, 148, 125, 41, 154, 196, 207, 18, 95, + 210, 76, 5, 80, 165, 167, 180, 14, 149, 98, 136, 29, 247, 65, 214, 62, 90, 127, 47, 44, 19, + 47, 16, 84, 210, 45, 33, 3, + ]; + let x0_c1_b1_a0_bytes = [ + 67, 64, 200, 83, 56, 98, 37, 156, 128, 197, 145, 165, 24, 7, 119, 161, 36, 53, 81, 104, + 132, 26, 28, 154, 249, 99, 147, 13, 200, 123, 226, 105, 94, 31, 96, 107, 114, 36, 246, 164, + 198, 23, 239, 186, 38, 4, 150, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 120, 243, 96, 185, 212, 141, 147, 104, 52, 239, 147, 173, 134, 47, 255, 170, 192, 225, 233, + 197, 7, 190, 254, 207, 196, 69, 228, 67, 11, 209, 193, 162, 29, 33, 62, 134, 198, 93, 171, + 104, 36, 55, 224, 195, 116, 124, 37, 5, + ]; + let x0_c1_b2_a0_bytes = [ + 149, 80, 66, 197, 78, 71, 174, 41, 148, 153, 187, 24, 17, 86, 155, 33, 16, 86, 221, 137, + 135, 115, 244, 2, 255, 150, 239, 226, 231, 115, 224, 37, 155, 126, 196, 79, 207, 144, 253, + 16, 159, 113, 37, 120, 77, 255, 73, 22, + ]; + let x0_c1_b2_a1_bytes = [ + 190, 175, 107, 235, 207, 189, 162, 102, 173, 62, 208, 181, 179, 166, 36, 90, 114, 111, 210, + 198, 113, 141, 199, 109, 94, 157, 183, 9, 128, 240, 121, 117, 148, 236, 238, 69, 107, 66, + 217, 41, 236, 99, 80, 244, 190, 82, 151, 1, + ]; + let x1_c0_b0_a0_bytes = [ + 232, 141, 62, 55, 243, 245, 168, 210, 31, 237, 239, 153, 14, 209, 115, 1, 206, 147, 183, + 64, 152, 81, 49, 18, 190, 179, 192, 37, 84, 115, 137, 165, 244, 132, 222, 69, 0, 30, 137, + 145, 103, 129, 61, 52, 250, 155, 219, 4, + ]; + let x1_c0_b0_a1_bytes = [ + 113, 139, 120, 115, 225, 148, 22, 187, 109, 115, 126, 91, 111, 145, 171, 208, 110, 106, + 149, 194, 93, 202, 135, 38, 207, 224, 84, 228, 29, 20, 108, 242, 236, 97, 233, 108, 121, + 144, 23, 153, 40, 223, 98, 234, 188, 44, 242, 6, + ]; + let x1_c0_b1_a0_bytes = [ + 152, 83, 177, 81, 25, 169, 168, 112, 215, 237, 121, 175, 120, 129, 75, 46, 55, 200, 16, + 106, 154, 231, 73, 168, 62, 216, 151, 228, 249, 41, 11, 107, 158, 140, 67, 215, 117, 16, + 84, 45, 234, 74, 151, 254, 184, 219, 116, 0, + ]; + let x1_c0_b1_a1_bytes = [ + 68, 35, 46, 47, 154, 117, 41, 42, 243, 148, 223, 144, 111, 107, 140, 207, 164, 68, 84, 243, + 64, 128, 254, 216, 177, 233, 131, 227, 40, 19, 194, 153, 248, 80, 201, 0, 127, 63, 59, 155, + 222, 127, 81, 60, 26, 190, 33, 15, + ]; + let x1_c0_b2_a0_bytes = [ + 248, 133, 135, 6, 150, 86, 28, 203, 165, 53, 190, 226, 99, 10, 36, 47, 226, 178, 239, 209, + 159, 91, 220, 5, 67, 62, 117, 35, 108, 130, 199, 12, 45, 245, 84, 40, 110, 201, 159, 184, + 237, 175, 154, 239, 164, 187, 131, 1, + ]; + let x1_c0_b2_a1_bytes = [ + 68, 107, 158, 70, 92, 137, 135, 220, 212, 245, 24, 214, 217, 210, 137, 220, 42, 191, 194, + 42, 243, 143, 219, 231, 52, 64, 89, 157, 205, 97, 52, 209, 9, 61, 136, 37, 202, 247, 64, + 166, 163, 249, 26, 95, 59, 255, 237, 7, + ]; + let x1_c1_b0_a0_bytes = [ + 169, 12, 166, 142, 127, 221, 90, 52, 130, 240, 103, 229, 157, 212, 117, 57, 95, 237, 195, + 145, 196, 87, 41, 204, 201, 55, 101, 137, 193, 53, 23, 73, 177, 252, 212, 131, 1, 89, 170, + 171, 222, 181, 216, 219, 162, 41, 228, 8, + ]; + let x1_c1_b0_a1_bytes = [ + 237, 98, 101, 211, 49, 237, 157, 16, 6, 61, 83, 201, 3, 96, 185, 153, 250, 216, 184, 117, + 159, 246, 233, 96, 23, 119, 118, 103, 88, 80, 126, 68, 66, 214, 147, 46, 209, 159, 243, 75, + 204, 240, 192, 84, 231, 18, 57, 17, + ]; + let x1_c1_b1_a0_bytes = [ + 104, 144, 181, 81, 179, 227, 108, 37, 237, 241, 87, 182, 122, 63, 188, 228, 195, 34, 131, + 244, 136, 121, 187, 97, 57, 55, 255, 12, 229, 30, 113, 5, 129, 97, 18, 46, 21, 43, 137, 24, + 204, 21, 47, 114, 88, 123, 199, 9, + ]; + let x1_c1_b1_a1_bytes = [ + 219, 73, 222, 238, 62, 66, 133, 212, 134, 204, 165, 110, 75, 169, 34, 254, 78, 131, 51, 67, + 27, 193, 8, 56, 180, 137, 126, 251, 241, 176, 69, 38, 15, 118, 107, 98, 68, 68, 96, 1, 144, + 214, 29, 31, 83, 179, 138, 6, + ]; + let x1_c1_b2_a0_bytes = [ + 200, 135, 142, 179, 186, 161, 77, 83, 223, 201, 62, 131, 26, 198, 122, 50, 188, 167, 41, + 219, 122, 80, 74, 9, 1, 233, 94, 222, 127, 179, 185, 37, 73, 200, 87, 78, 147, 149, 225, + 52, 187, 134, 144, 110, 101, 198, 248, 11, + ]; + let x1_c1_b2_a1_bytes = [ + 161, 101, 7, 76, 21, 58, 5, 167, 239, 173, 64, 201, 247, 135, 227, 46, 142, 173, 1, 178, + 43, 222, 120, 104, 27, 246, 152, 18, 240, 122, 233, 85, 242, 136, 136, 113, 15, 145, 142, + 200, 124, 118, 22, 138, 12, 152, 9, 22, + ]; + let x2_c0_b0_a0_bytes = [ + 221, 94, 131, 30, 174, 121, 117, 151, 52, 170, 211, 119, 234, 15, 225, 242, 28, 4, 228, + 254, 177, 71, 126, 147, 45, 202, 216, 92, 159, 248, 132, 195, 0, 152, 10, 37, 75, 27, 213, + 216, 236, 50, 218, 4, 73, 210, 220, 5, + ]; + let x2_c0_b0_a1_bytes = [ + 202, 212, 163, 190, 163, 79, 160, 142, 212, 102, 99, 73, 82, 42, 74, 23, 118, 86, 173, 134, + 77, 208, 18, 24, 182, 161, 148, 17, 143, 131, 121, 235, 199, 48, 57, 101, 23, 197, 220, + 185, 143, 216, 184, 201, 34, 92, 116, 17, + ]; + let x2_c0_b1_a0_bytes = [ + 188, 227, 41, 37, 126, 220, 117, 224, 63, 147, 94, 78, 111, 17, 6, 193, 27, 199, 126, 134, + 255, 214, 17, 141, 133, 75, 222, 153, 132, 75, 217, 51, 148, 30, 224, 187, 30, 88, 114, 5, + 133, 246, 1, 102, 60, 163, 7, 7, + ]; + let x2_c0_b1_a1_bytes = [ + 5, 207, 144, 71, 230, 24, 8, 248, 173, 220, 235, 185, 27, 21, 185, 85, 40, 210, 207, 9, 90, + 36, 3, 240, 207, 42, 163, 107, 104, 98, 203, 191, 103, 163, 221, 161, 26, 157, 8, 128, 40, + 215, 6, 16, 10, 221, 159, 12, + ]; + let x2_c0_b2_a0_bytes = [ + 236, 159, 252, 92, 44, 66, 14, 105, 112, 57, 171, 78, 233, 126, 107, 170, 18, 150, 185, 33, + 74, 213, 179, 128, 37, 146, 23, 255, 144, 99, 196, 168, 220, 49, 191, 220, 194, 96, 53, 94, + 96, 146, 9, 189, 202, 35, 55, 21, + ]; + let x2_c0_b2_a1_bytes = [ + 23, 148, 182, 164, 38, 25, 48, 253, 18, 129, 105, 33, 219, 46, 51, 180, 173, 231, 202, 12, + 85, 24, 181, 23, 188, 160, 164, 24, 246, 104, 102, 183, 133, 70, 144, 230, 71, 62, 37, 34, + 225, 185, 6, 234, 13, 130, 138, 16, + ]; + let x2_c1_b0_a0_bytes = [ + 77, 47, 197, 137, 97, 75, 3, 172, 117, 159, 16, 29, 83, 249, 160, 70, 34, 88, 183, 125, + 179, 82, 211, 52, 88, 21, 8, 231, 81, 62, 222, 113, 209, 185, 238, 247, 192, 180, 65, 149, + 35, 96, 222, 229, 167, 133, 225, 18, + ]; + let x2_c1_b0_a1_bytes = [ + 129, 167, 163, 66, 73, 106, 182, 221, 118, 212, 135, 181, 143, 29, 28, 31, 238, 236, 10, + 224, 211, 40, 76, 86, 77, 67, 195, 154, 193, 93, 129, 61, 140, 24, 142, 83, 63, 135, 87, + 43, 225, 36, 207, 56, 213, 44, 233, 11, + ]; + let x2_c1_b1_a0_bytes = [ + 219, 175, 18, 2, 133, 126, 184, 118, 147, 211, 57, 239, 157, 199, 186, 188, 96, 18, 206, + 115, 251, 160, 96, 56, 192, 44, 148, 0, 227, 92, 113, 100, 221, 189, 77, 61, 93, 249, 108, + 140, 250, 1, 192, 72, 206, 136, 206, 1, + ]; + let x2_c1_b1_a1_bytes = [ + 72, 84, 130, 202, 149, 75, 13, 78, 173, 34, 66, 240, 57, 134, 136, 203, 149, 84, 103, 121, + 141, 207, 38, 255, 207, 206, 234, 59, 158, 107, 243, 224, 229, 87, 30, 103, 56, 193, 102, + 178, 46, 71, 66, 222, 11, 219, 155, 24, + ]; + let x2_c1_b2_a0_bytes = [ + 205, 200, 179, 17, 148, 165, 96, 214, 180, 207, 124, 149, 246, 143, 32, 239, 83, 174, 179, + 174, 12, 35, 170, 249, 253, 173, 144, 4, 104, 192, 38, 0, 82, 182, 108, 1, 60, 251, 27, + 220, 227, 234, 148, 9, 232, 56, 81, 10, + ]; + let x2_c1_b2_a1_bytes = [ + 200, 244, 99, 159, 186, 131, 156, 121, 189, 144, 227, 157, 186, 30, 237, 73, 8, 184, 129, + 11, 231, 129, 127, 108, 2, 186, 163, 234, 20, 193, 7, 132, 121, 16, 178, 23, 18, 89, 102, + 172, 9, 212, 185, 163, 156, 204, 142, 5, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12SubCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12MulCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.mul(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_mul() { + compile_generic(&E12MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12MulCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 18, 16, 175, 85, 34, 237, 118, 71, 162, 164, 89, 178, 78, 181, 29, 51, 79, 100, 35, 97, + 196, 220, 121, 215, 157, 189, 144, 26, 67, 25, 143, 143, 42, 101, 231, 240, 230, 220, 139, + 229, 187, 86, 239, 244, 109, 91, 143, 20, + ]; + let x0_c0_b0_a1_bytes = [ + 104, 153, 197, 146, 135, 101, 130, 39, 74, 182, 160, 38, 197, 224, 5, 133, 142, 105, 202, + 217, 215, 240, 244, 171, 157, 55, 89, 59, 188, 205, 135, 43, 127, 31, 166, 190, 9, 193, 93, + 205, 58, 226, 101, 14, 153, 21, 234, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 126, 212, 100, 36, 202, 52, 184, 67, 214, 199, 123, 245, 2, 167, 137, 57, 81, 54, 78, 8, + 204, 178, 55, 15, 220, 40, 57, 37, 167, 232, 27, 33, 243, 213, 212, 233, 46, 43, 145, 49, + 208, 94, 159, 54, 61, 86, 74, 22, + ]; + let x0_c0_b1_a1_bytes = [ + 174, 111, 11, 165, 30, 60, 48, 155, 87, 253, 31, 26, 63, 238, 208, 50, 127, 61, 238, 214, + 152, 200, 10, 111, 92, 23, 141, 127, 190, 250, 186, 237, 78, 143, 238, 113, 111, 124, 32, + 10, 61, 131, 95, 58, 154, 188, 144, 25, + ]; + let x0_c0_b2_a0_bytes = [ + 59, 200, 148, 183, 6, 226, 234, 205, 189, 41, 155, 50, 205, 1, 73, 159, 234, 93, 20, 65, 7, + 210, 176, 195, 242, 149, 31, 36, 66, 79, 103, 232, 182, 29, 129, 100, 127, 55, 143, 74, 76, + 224, 7, 87, 128, 229, 156, 13, + ]; + let x0_c0_b2_a1_bytes = [ + 110, 72, 137, 164, 201, 4, 40, 254, 210, 231, 146, 39, 192, 152, 171, 24, 237, 83, 153, + 179, 26, 97, 200, 122, 36, 82, 239, 217, 181, 231, 62, 128, 66, 227, 0, 198, 91, 252, 165, + 196, 81, 198, 154, 73, 96, 55, 209, 19, + ]; + let x0_c1_b0_a0_bytes = [ + 169, 129, 186, 227, 169, 163, 212, 206, 238, 76, 175, 179, 26, 251, 188, 55, 225, 254, 135, + 143, 106, 185, 34, 137, 192, 89, 157, 244, 186, 116, 163, 155, 250, 100, 254, 217, 201, 88, + 143, 57, 13, 253, 249, 223, 180, 181, 154, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 241, 145, 54, 93, 184, 84, 47, 57, 100, 101, 64, 216, 140, 119, 185, 24, 79, 78, 187, 112, + 137, 186, 170, 29, 142, 240, 58, 182, 135, 206, 87, 185, 164, 140, 72, 144, 75, 219, 55, + 197, 124, 20, 45, 213, 71, 6, 195, 7, + ]; + let x0_c1_b1_a0_bytes = [ + 205, 64, 90, 100, 21, 169, 136, 39, 56, 72, 95, 160, 189, 175, 183, 219, 70, 48, 253, 114, + 208, 195, 195, 42, 203, 148, 99, 109, 232, 156, 175, 222, 224, 133, 192, 52, 178, 135, 98, + 208, 120, 253, 167, 40, 242, 93, 35, 25, + ]; + let x0_c1_b1_a1_bytes = [ + 3, 148, 43, 205, 241, 107, 73, 27, 92, 128, 127, 56, 26, 71, 93, 197, 106, 244, 30, 151, + 227, 100, 3, 100, 35, 57, 155, 142, 253, 223, 146, 199, 123, 9, 30, 111, 201, 199, 61, 77, + 22, 183, 200, 140, 225, 254, 194, 20, + ]; + let x0_c1_b2_a0_bytes = [ + 50, 105, 205, 33, 216, 5, 48, 84, 66, 141, 202, 6, 27, 142, 141, 74, 204, 171, 60, 145, + 125, 247, 88, 64, 93, 126, 118, 112, 109, 230, 100, 16, 42, 239, 204, 160, 230, 2, 7, 85, + 120, 155, 87, 196, 244, 159, 199, 20, + ]; + let x0_c1_b2_a1_bytes = [ + 11, 173, 240, 71, 15, 10, 199, 212, 101, 196, 123, 200, 143, 223, 216, 254, 40, 78, 66, + 163, 117, 205, 134, 253, 18, 21, 17, 37, 196, 124, 210, 118, 177, 48, 105, 105, 114, 222, + 224, 205, 37, 180, 65, 198, 34, 48, 34, 19, + ]; + let x1_c0_b0_a0_bytes = [ + 240, 137, 36, 51, 174, 210, 159, 102, 67, 7, 163, 220, 57, 196, 207, 116, 18, 202, 148, + 248, 6, 45, 135, 188, 79, 72, 55, 149, 74, 111, 220, 241, 23, 21, 151, 196, 186, 87, 250, + 144, 144, 213, 24, 190, 214, 125, 110, 0, + ]; + let x1_c0_b0_a1_bytes = [ + 60, 27, 22, 130, 117, 251, 130, 122, 140, 235, 142, 212, 10, 48, 246, 0, 141, 46, 146, 86, + 0, 78, 161, 219, 203, 39, 120, 253, 162, 34, 241, 239, 135, 28, 181, 205, 147, 187, 157, + 15, 119, 201, 81, 87, 222, 90, 58, 15, + ]; + let x1_c0_b1_a0_bytes = [ + 73, 70, 72, 123, 87, 235, 173, 13, 165, 233, 46, 210, 182, 119, 13, 209, 194, 46, 94, 218, + 156, 61, 214, 26, 55, 96, 204, 141, 85, 154, 101, 53, 136, 157, 105, 5, 166, 92, 37, 60, + 137, 148, 88, 87, 165, 203, 87, 7, + ]; + let x1_c0_b1_a1_bytes = [ + 251, 149, 3, 244, 35, 194, 49, 215, 250, 29, 193, 89, 177, 75, 111, 95, 111, 154, 179, 253, + 102, 196, 56, 147, 204, 115, 142, 158, 81, 35, 6, 136, 144, 196, 124, 75, 34, 79, 141, 40, + 83, 27, 86, 225, 184, 50, 232, 8, + ]; + let x1_c0_b2_a0_bytes = [ + 234, 29, 186, 114, 252, 192, 80, 101, 188, 72, 170, 15, 249, 50, 15, 0, 160, 97, 98, 53, + 174, 3, 132, 228, 15, 4, 19, 169, 15, 44, 22, 142, 62, 56, 151, 39, 209, 206, 103, 243, + 213, 24, 22, 195, 30, 64, 99, 17, + ]; + let x1_c0_b2_a1_bytes = [ + 41, 14, 48, 194, 233, 49, 189, 213, 184, 242, 130, 15, 112, 59, 59, 234, 226, 157, 204, + 127, 56, 179, 33, 102, 35, 151, 38, 172, 186, 116, 139, 125, 145, 252, 155, 113, 15, 235, + 96, 231, 238, 29, 176, 208, 83, 108, 34, 2, + ]; + let x1_c1_b0_a0_bytes = [ + 217, 237, 38, 213, 242, 122, 12, 249, 193, 156, 147, 167, 44, 167, 3, 183, 85, 155, 233, + 78, 216, 78, 93, 112, 51, 27, 189, 239, 13, 26, 99, 243, 161, 105, 227, 210, 70, 112, 48, + 163, 95, 44, 166, 114, 32, 48, 105, 5, + ]; + let x1_c1_b0_a1_bytes = [ + 191, 202, 154, 207, 61, 76, 176, 195, 236, 143, 41, 42, 233, 188, 57, 152, 85, 0, 209, 84, + 229, 123, 83, 90, 140, 34, 165, 96, 229, 100, 135, 105, 223, 248, 110, 29, 49, 133, 47, + 184, 223, 49, 107, 242, 204, 125, 92, 3, + ]; + let x1_c1_b1_a0_bytes = [ + 222, 196, 209, 22, 166, 64, 174, 112, 126, 200, 126, 250, 49, 210, 117, 146, 45, 137, 127, + 17, 219, 141, 59, 149, 231, 145, 239, 87, 50, 126, 73, 225, 42, 34, 121, 105, 159, 119, + 218, 242, 58, 177, 63, 23, 17, 41, 141, 8, + ]; + let x1_c1_b1_a1_bytes = [ + 51, 253, 245, 231, 88, 162, 251, 225, 148, 169, 24, 17, 157, 53, 128, 177, 87, 114, 85, + 154, 248, 125, 173, 180, 139, 181, 126, 221, 114, 103, 18, 252, 227, 219, 115, 161, 71, 38, + 91, 200, 247, 35, 62, 25, 118, 250, 65, 0, + ]; + let x1_c1_b2_a0_bytes = [ + 60, 154, 232, 54, 209, 216, 161, 46, 119, 93, 48, 165, 158, 118, 33, 17, 110, 132, 136, 27, + 135, 15, 232, 41, 84, 241, 133, 44, 214, 113, 211, 204, 78, 161, 220, 224, 59, 249, 51, + 242, 55, 121, 161, 124, 16, 252, 218, 12, + ]; + let x1_c1_b2_a1_bytes = [ + 137, 242, 221, 198, 166, 207, 120, 212, 128, 29, 46, 23, 109, 110, 227, 228, 253, 14, 75, + 143, 148, 245, 84, 86, 227, 73, 113, 139, 53, 141, 58, 222, 227, 204, 186, 104, 124, 18, + 92, 243, 14, 223, 234, 223, 53, 146, 68, 22, + ]; + let x2_c0_b0_a0_bytes = [ + 1, 149, 245, 118, 70, 112, 151, 116, 114, 158, 58, 126, 125, 134, 169, 173, 222, 62, 254, + 247, 138, 110, 222, 181, 49, 16, 20, 74, 190, 252, 59, 26, 36, 244, 53, 89, 3, 29, 193, 41, + 53, 209, 151, 162, 227, 23, 35, 0, + ]; + let x2_c0_b0_a1_bytes = [ + 198, 137, 108, 161, 94, 178, 221, 160, 92, 142, 20, 161, 203, 198, 212, 161, 200, 102, 184, + 1, 149, 19, 54, 172, 181, 0, 3, 60, 164, 25, 179, 27, 126, 101, 101, 152, 48, 39, 140, 137, + 227, 188, 234, 142, 37, 82, 42, 4, + ]; + let x2_c0_b1_a0_bytes = [ + 214, 32, 230, 177, 23, 76, 224, 158, 211, 4, 191, 255, 210, 124, 182, 226, 204, 174, 70, + 49, 245, 52, 187, 68, 199, 33, 75, 141, 112, 46, 163, 151, 1, 33, 37, 156, 0, 98, 15, 207, + 86, 18, 181, 185, 56, 135, 13, 21, + ]; + let x2_c0_b1_a1_bytes = [ + 237, 204, 148, 175, 56, 19, 91, 99, 62, 247, 203, 193, 89, 176, 166, 172, 184, 135, 23, + 202, 116, 113, 247, 209, 30, 200, 205, 54, 205, 157, 22, 248, 203, 154, 207, 92, 217, 65, + 253, 33, 229, 230, 110, 97, 247, 33, 227, 2, + ]; + let x2_c0_b2_a0_bytes = [ + 152, 32, 127, 72, 230, 253, 163, 95, 208, 104, 71, 35, 71, 74, 212, 182, 56, 212, 49, 178, + 60, 242, 97, 255, 142, 26, 231, 104, 20, 239, 71, 46, 18, 172, 158, 162, 119, 39, 155, 4, + 115, 149, 45, 17, 160, 11, 183, 23, + ]; + let x2_c0_b2_a1_bytes = [ + 214, 55, 28, 255, 211, 238, 206, 210, 80, 24, 120, 165, 76, 1, 7, 137, 190, 11, 229, 167, + 236, 55, 145, 134, 15, 8, 208, 168, 180, 16, 172, 229, 206, 73, 58, 192, 98, 16, 104, 193, + 130, 66, 39, 57, 178, 252, 154, 5, + ]; + let x2_c1_b0_a0_bytes = [ + 19, 208, 0, 191, 6, 160, 11, 114, 241, 154, 85, 194, 234, 149, 134, 185, 117, 13, 200, 110, + 62, 249, 86, 202, 195, 194, 53, 143, 244, 54, 68, 254, 65, 245, 221, 102, 189, 221, 246, + 48, 202, 113, 195, 17, 47, 172, 205, 16, + ]; + let x2_c1_b0_a1_bytes = [ + 24, 133, 121, 38, 233, 140, 70, 206, 19, 114, 131, 40, 250, 61, 165, 157, 3, 218, 12, 156, + 3, 36, 100, 173, 78, 73, 161, 18, 88, 169, 101, 4, 224, 138, 37, 192, 33, 69, 119, 196, + 203, 122, 166, 212, 20, 40, 199, 18, + ]; + let x2_c1_b1_a0_bytes = [ + 58, 180, 157, 138, 178, 143, 59, 160, 99, 147, 56, 53, 155, 35, 65, 227, 23, 162, 191, 243, + 139, 206, 20, 109, 42, 13, 184, 41, 77, 101, 92, 30, 49, 177, 61, 60, 171, 10, 114, 10, + 185, 131, 252, 40, 88, 232, 201, 10, + ]; + let x2_c1_b1_a1_bytes = [ + 117, 238, 170, 146, 84, 80, 82, 70, 144, 134, 148, 70, 182, 153, 18, 73, 252, 151, 171, + 118, 161, 113, 93, 115, 101, 127, 97, 90, 146, 232, 114, 159, 164, 237, 232, 31, 140, 217, + 160, 112, 142, 153, 50, 230, 151, 207, 201, 7, + ]; + let x2_c1_b2_a0_bytes = [ + 218, 19, 179, 196, 132, 93, 249, 221, 47, 165, 80, 237, 178, 80, 214, 236, 26, 67, 226, + 252, 234, 204, 11, 109, 4, 246, 171, 23, 82, 14, 26, 104, 36, 222, 236, 91, 194, 103, 215, + 93, 97, 69, 49, 212, 61, 2, 222, 11, + ]; + let x2_c1_b2_a1_bytes = [ + 8, 132, 51, 137, 1, 206, 121, 67, 104, 212, 9, 238, 140, 14, 73, 74, 65, 177, 167, 226, + 127, 90, 220, 71, 34, 121, 96, 219, 11, 245, 16, 53, 63, 140, 54, 254, 35, 201, 17, 108, + 96, 16, 132, 144, 60, 143, 127, 3, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12DivCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.div(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_div() { + compile_generic(&E12DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12DivCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 254, 180, 220, 147, 183, 118, 153, 36, 195, 182, 38, 75, 52, 106, 65, 31, 129, 247, 165, + 36, 249, 44, 176, 1, 42, 106, 237, 185, 148, 192, 231, 0, 123, 186, 60, 239, 65, 203, 166, + 161, 15, 211, 65, 114, 65, 36, 80, 3, + ]; + let x0_c0_b0_a1_bytes = [ + 58, 127, 245, 147, 170, 27, 20, 107, 100, 56, 192, 22, 167, 172, 88, 219, 98, 126, 91, 86, + 29, 142, 117, 156, 166, 36, 223, 50, 161, 179, 178, 252, 125, 164, 147, 159, 249, 111, 70, + 48, 106, 58, 142, 112, 204, 211, 72, 18, + ]; + let x0_c0_b1_a0_bytes = [ + 151, 18, 147, 147, 3, 131, 131, 230, 185, 24, 54, 136, 249, 234, 141, 241, 80, 44, 100, + 169, 203, 250, 245, 208, 130, 171, 36, 70, 145, 68, 7, 223, 110, 161, 240, 4, 188, 221, + 252, 143, 243, 16, 70, 147, 121, 203, 207, 23, + ]; + let x0_c0_b1_a1_bytes = [ + 121, 192, 157, 27, 84, 232, 248, 218, 216, 193, 26, 58, 161, 185, 51, 106, 144, 142, 48, + 62, 254, 62, 201, 224, 38, 98, 44, 105, 90, 96, 51, 6, 219, 241, 23, 198, 109, 39, 66, 76, + 236, 6, 84, 98, 197, 72, 92, 7, + ]; + let x0_c0_b2_a0_bytes = [ + 183, 181, 165, 60, 147, 229, 250, 166, 11, 193, 79, 192, 12, 161, 71, 94, 96, 212, 33, 91, + 80, 90, 141, 52, 246, 64, 44, 85, 182, 252, 39, 164, 76, 235, 131, 247, 38, 57, 62, 96, + 252, 55, 9, 170, 175, 36, 14, 11, + ]; + let x0_c0_b2_a1_bytes = [ + 189, 156, 0, 235, 163, 90, 36, 226, 124, 135, 231, 181, 119, 172, 9, 171, 212, 53, 232, 31, + 193, 188, 40, 186, 228, 71, 128, 43, 21, 97, 254, 245, 137, 234, 155, 125, 218, 241, 206, + 42, 136, 184, 220, 122, 164, 26, 18, 23, + ]; + let x0_c1_b0_a0_bytes = [ + 200, 146, 209, 175, 82, 195, 145, 241, 54, 31, 18, 193, 200, 8, 41, 161, 43, 94, 59, 219, + 81, 128, 85, 13, 162, 9, 141, 39, 157, 70, 246, 131, 164, 104, 76, 227, 219, 42, 112, 136, + 166, 45, 200, 246, 225, 51, 28, 16, + ]; + let x0_c1_b0_a1_bytes = [ + 54, 115, 148, 26, 219, 101, 46, 245, 26, 216, 90, 142, 45, 183, 28, 250, 222, 213, 38, 96, + 62, 92, 225, 241, 52, 207, 25, 59, 75, 34, 131, 253, 200, 155, 159, 146, 254, 106, 174, + 192, 21, 208, 115, 104, 89, 82, 201, 12, + ]; + let x0_c1_b1_a0_bytes = [ + 46, 14, 236, 125, 150, 59, 135, 79, 129, 202, 43, 29, 226, 36, 157, 208, 201, 235, 145, 77, + 132, 64, 130, 98, 74, 100, 107, 125, 50, 147, 171, 37, 61, 119, 183, 122, 28, 64, 223, 191, + 159, 52, 64, 220, 183, 77, 68, 24, + ]; + let x0_c1_b1_a1_bytes = [ + 120, 70, 77, 94, 71, 235, 65, 233, 161, 74, 206, 155, 203, 39, 168, 202, 136, 61, 64, 186, + 114, 75, 137, 76, 47, 131, 84, 47, 137, 223, 249, 64, 195, 103, 21, 145, 78, 20, 37, 241, + 150, 118, 48, 64, 106, 50, 197, 1, + ]; + let x0_c1_b2_a0_bytes = [ + 17, 70, 175, 245, 238, 38, 4, 224, 115, 31, 107, 233, 28, 224, 149, 204, 77, 150, 169, 55, + 196, 94, 107, 75, 35, 11, 131, 95, 212, 212, 103, 64, 210, 147, 241, 48, 58, 129, 205, 213, + 250, 8, 69, 13, 93, 27, 215, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 42, 34, 192, 185, 113, 199, 199, 165, 168, 0, 80, 76, 229, 232, 229, 191, 97, 111, 8, 96, + 226, 177, 83, 192, 195, 209, 33, 216, 64, 40, 10, 244, 85, 12, 215, 16, 249, 93, 55, 53, + 217, 94, 24, 147, 149, 76, 113, 6, + ]; + let x1_c0_b0_a0_bytes = [ + 60, 92, 218, 84, 110, 123, 199, 41, 87, 94, 192, 231, 66, 152, 5, 186, 92, 211, 103, 33, + 232, 228, 151, 5, 206, 231, 89, 46, 57, 39, 158, 50, 208, 83, 252, 217, 228, 52, 254, 107, + 229, 46, 105, 152, 31, 93, 35, 17, + ]; + let x1_c0_b0_a1_bytes = [ + 106, 251, 2, 54, 89, 25, 70, 97, 241, 184, 44, 143, 138, 187, 197, 209, 110, 166, 22, 156, + 71, 37, 31, 87, 29, 181, 17, 61, 83, 135, 73, 230, 255, 106, 77, 58, 230, 157, 180, 41, 5, + 26, 227, 40, 196, 78, 186, 17, + ]; + let x1_c0_b1_a0_bytes = [ + 92, 84, 110, 29, 202, 71, 43, 200, 70, 116, 31, 50, 19, 195, 144, 50, 12, 139, 209, 28, 36, + 225, 89, 241, 99, 233, 171, 30, 24, 3, 155, 50, 66, 251, 10, 200, 186, 86, 96, 105, 213, + 248, 85, 248, 110, 35, 26, 15, + ]; + let x1_c0_b1_a1_bytes = [ + 173, 116, 187, 196, 213, 153, 240, 42, 151, 106, 69, 11, 251, 231, 152, 77, 136, 117, 57, + 154, 178, 108, 49, 165, 171, 24, 80, 207, 93, 16, 90, 195, 135, 66, 214, 92, 73, 4, 104, + 238, 29, 167, 252, 105, 52, 81, 23, 22, + ]; + let x1_c0_b2_a0_bytes = [ + 253, 140, 214, 65, 230, 229, 249, 148, 5, 249, 97, 222, 240, 204, 100, 136, 64, 100, 75, + 68, 242, 70, 163, 21, 135, 141, 119, 166, 131, 42, 135, 3, 194, 210, 22, 59, 225, 133, 172, + 6, 16, 40, 181, 52, 69, 227, 26, 21, + ]; + let x1_c0_b2_a1_bytes = [ + 137, 181, 69, 64, 102, 26, 114, 215, 0, 254, 8, 156, 53, 38, 158, 33, 146, 155, 37, 52, + 246, 157, 120, 135, 96, 158, 208, 90, 4, 175, 163, 68, 23, 3, 241, 72, 20, 104, 92, 28, 13, + 67, 243, 77, 23, 215, 179, 19, + ]; + let x1_c1_b0_a0_bytes = [ + 191, 220, 69, 111, 219, 69, 192, 59, 150, 42, 118, 235, 174, 95, 241, 145, 147, 190, 224, + 65, 24, 164, 80, 235, 5, 139, 74, 198, 133, 37, 191, 215, 254, 131, 233, 11, 159, 122, 64, + 226, 236, 56, 135, 186, 246, 167, 252, 21, + ]; + let x1_c1_b0_a1_bytes = [ + 108, 243, 84, 77, 223, 98, 25, 156, 113, 210, 47, 53, 192, 254, 227, 74, 12, 183, 85, 153, + 146, 247, 161, 172, 86, 65, 68, 123, 204, 144, 221, 107, 98, 46, 176, 204, 146, 72, 63, + 145, 71, 177, 139, 186, 180, 139, 12, 6, + ]; + let x1_c1_b1_a0_bytes = [ + 95, 108, 116, 45, 180, 244, 62, 115, 53, 224, 132, 50, 185, 217, 204, 60, 186, 144, 222, + 208, 83, 181, 49, 156, 28, 44, 121, 85, 31, 90, 218, 15, 179, 99, 131, 15, 76, 228, 231, + 151, 54, 50, 127, 19, 13, 29, 231, 21, + ]; + let x1_c1_b1_a1_bytes = [ + 208, 84, 155, 33, 71, 227, 55, 60, 166, 69, 70, 175, 217, 19, 65, 151, 96, 229, 196, 237, + 185, 71, 127, 24, 116, 26, 180, 160, 101, 9, 181, 128, 127, 140, 20, 237, 51, 116, 229, 87, + 4, 70, 219, 177, 136, 38, 190, 10, + ]; + let x1_c1_b2_a0_bytes = [ + 110, 182, 233, 157, 108, 35, 70, 151, 135, 60, 100, 224, 22, 31, 244, 228, 93, 8, 123, 41, + 197, 189, 48, 115, 15, 13, 226, 43, 179, 173, 65, 228, 169, 140, 61, 83, 207, 232, 250, + 179, 24, 134, 51, 212, 101, 172, 196, 0, + ]; + let x1_c1_b2_a1_bytes = [ + 23, 226, 188, 161, 124, 0, 174, 246, 12, 60, 212, 16, 30, 23, 148, 45, 120, 66, 11, 61, + 225, 76, 178, 199, 73, 143, 156, 121, 137, 33, 85, 79, 171, 168, 197, 87, 245, 121, 93, + 254, 29, 223, 214, 163, 159, 182, 77, 25, + ]; + let x2_c0_b0_a0_bytes = [ + 193, 85, 60, 41, 60, 152, 106, 114, 148, 237, 154, 211, 214, 196, 213, 101, 115, 247, 217, + 223, 117, 55, 13, 175, 77, 123, 244, 52, 227, 28, 169, 27, 217, 47, 69, 149, 188, 93, 70, + 195, 43, 183, 207, 133, 86, 80, 194, 10, + ]; + let x2_c0_b0_a1_bytes = [ + 36, 127, 151, 163, 201, 85, 223, 30, 16, 103, 144, 95, 65, 225, 213, 110, 31, 137, 215, + 101, 254, 117, 77, 161, 242, 65, 131, 175, 78, 158, 70, 195, 181, 212, 1, 41, 189, 131, + 187, 191, 33, 51, 232, 34, 165, 99, 97, 4, + ]; + let x2_c0_b1_a0_bytes = [ + 44, 106, 74, 150, 120, 208, 238, 66, 3, 250, 179, 67, 229, 57, 59, 90, 42, 240, 255, 7, 57, + 35, 228, 233, 92, 6, 27, 158, 84, 101, 228, 120, 131, 163, 134, 252, 160, 195, 147, 169, + 94, 217, 133, 110, 3, 36, 169, 14, + ]; + let x2_c0_b1_a1_bytes = [ + 207, 75, 223, 255, 56, 145, 37, 87, 131, 151, 214, 99, 155, 236, 192, 39, 57, 184, 80, 4, + 204, 139, 105, 209, 89, 221, 48, 231, 216, 143, 50, 106, 51, 240, 179, 216, 42, 92, 12, + 208, 162, 59, 252, 106, 187, 52, 78, 14, + ]; + let x2_c0_b2_a0_bytes = [ + 44, 163, 90, 136, 20, 187, 82, 175, 60, 123, 68, 24, 184, 102, 100, 24, 63, 8, 135, 105, 0, + 199, 31, 20, 76, 35, 214, 148, 84, 105, 12, 191, 159, 196, 105, 93, 143, 74, 141, 66, 144, + 145, 35, 193, 91, 237, 131, 17, + ]; + let x2_c0_b2_a1_bytes = [ + 0, 57, 213, 117, 115, 227, 33, 33, 242, 96, 162, 92, 199, 126, 170, 210, 90, 42, 239, 201, + 182, 137, 254, 147, 209, 115, 88, 138, 184, 7, 209, 171, 204, 145, 116, 8, 81, 149, 240, + 199, 215, 224, 91, 183, 175, 14, 114, 24, + ]; + let x2_c1_b0_a0_bytes = [ + 107, 154, 39, 211, 222, 105, 63, 163, 49, 5, 83, 98, 183, 5, 225, 130, 171, 221, 182, 166, + 175, 207, 123, 42, 34, 243, 78, 52, 125, 132, 149, 71, 217, 140, 159, 127, 245, 185, 119, + 173, 169, 45, 59, 3, 168, 213, 214, 3, + ]; + let x2_c1_b0_a1_bytes = [ + 41, 123, 78, 190, 56, 110, 2, 65, 52, 247, 49, 179, 167, 29, 231, 228, 230, 200, 225, 201, + 125, 207, 251, 92, 191, 56, 173, 61, 137, 11, 175, 65, 228, 18, 121, 196, 134, 228, 2, 210, + 12, 3, 33, 212, 17, 25, 4, 20, + ]; + let x2_c1_b1_a0_bytes = [ + 73, 240, 43, 201, 245, 221, 180, 227, 71, 110, 86, 238, 235, 55, 11, 107, 92, 120, 130, 19, + 228, 202, 128, 10, 18, 152, 0, 147, 39, 137, 150, 101, 173, 186, 0, 4, 168, 152, 25, 126, + 111, 212, 205, 16, 197, 159, 87, 8, + ]; + let x2_c1_b1_a1_bytes = [ + 241, 9, 83, 199, 86, 120, 96, 84, 72, 214, 186, 152, 30, 128, 230, 207, 67, 248, 15, 247, + 245, 117, 250, 32, 214, 193, 219, 69, 24, 112, 89, 102, 226, 19, 43, 231, 198, 14, 141, 1, + 110, 7, 177, 148, 133, 72, 114, 13, + ]; + let x2_c1_b2_a0_bytes = [ + 141, 154, 251, 95, 73, 155, 76, 96, 218, 10, 96, 92, 236, 217, 69, 22, 189, 223, 80, 166, + 99, 163, 248, 207, 18, 31, 22, 51, 34, 37, 225, 6, 148, 150, 160, 141, 243, 6, 220, 106, + 158, 239, 73, 179, 78, 81, 96, 9, + ]; + let x2_c1_b2_a1_bytes = [ + 251, 124, 170, 135, 94, 22, 235, 110, 117, 182, 48, 254, 114, 133, 34, 113, 83, 69, 102, + 241, 200, 233, 124, 188, 239, 165, 178, 171, 57, 37, 214, 60, 30, 131, 116, 44, 118, 206, + 190, 85, 20, 118, 212, 69, 194, 20, 81, 16, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12SquareCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.square(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_square() { + compile_generic(&E12SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12SquareCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 88, 133, 252, 130, 248, 35, 113, 86, 1, 233, 243, 26, 171, 123, 147, 247, 95, 0, 7, 89, + 214, 56, 125, 216, 216, 127, 82, 24, 54, 235, 55, 222, 80, 208, 90, 30, 69, 10, 30, 120, + 48, 239, 117, 55, 217, 64, 92, 3, + ]; + let x0_c0_b0_a1_bytes = [ + 47, 64, 88, 248, 212, 179, 29, 77, 32, 27, 51, 247, 199, 202, 142, 158, 234, 53, 177, 201, + 181, 197, 9, 1, 31, 109, 21, 63, 26, 22, 191, 120, 78, 20, 57, 233, 71, 10, 97, 44, 87, + 107, 192, 4, 172, 27, 240, 7, + ]; + let x0_c0_b1_a0_bytes = [ + 111, 64, 203, 144, 84, 246, 36, 84, 242, 40, 158, 185, 116, 81, 136, 56, 251, 133, 233, + 214, 83, 122, 228, 55, 216, 140, 109, 26, 132, 43, 108, 73, 117, 38, 229, 19, 179, 243, + 194, 140, 171, 145, 49, 72, 198, 113, 51, 3, + ]; + let x0_c0_b1_a1_bytes = [ + 2, 221, 248, 230, 28, 200, 185, 145, 172, 223, 125, 173, 202, 235, 152, 115, 44, 129, 108, + 105, 30, 91, 192, 218, 226, 80, 249, 76, 17, 193, 35, 250, 4, 9, 113, 22, 3, 93, 184, 59, + 69, 215, 238, 187, 14, 11, 126, 12, + ]; + let x0_c0_b2_a0_bytes = [ + 27, 66, 201, 99, 213, 78, 185, 239, 188, 95, 52, 87, 91, 2, 47, 201, 133, 144, 37, 59, 95, + 204, 68, 241, 81, 241, 17, 237, 119, 31, 105, 139, 9, 146, 5, 39, 56, 173, 211, 225, 43, + 100, 93, 64, 31, 193, 100, 10, + ]; + let x0_c0_b2_a1_bytes = [ + 228, 177, 70, 5, 221, 20, 28, 35, 107, 127, 168, 19, 216, 192, 192, 181, 75, 230, 226, 61, + 207, 8, 216, 81, 59, 93, 251, 237, 217, 32, 38, 31, 95, 239, 31, 7, 145, 48, 34, 226, 221, + 44, 148, 141, 166, 180, 57, 7, + ]; + let x0_c1_b0_a0_bytes = [ + 33, 25, 52, 14, 225, 200, 176, 33, 108, 144, 161, 200, 90, 168, 64, 62, 88, 113, 62, 78, + 211, 132, 185, 129, 131, 61, 99, 106, 157, 96, 28, 164, 122, 234, 91, 235, 157, 10, 45, 85, + 72, 219, 225, 17, 132, 159, 195, 5, + ]; + let x0_c1_b0_a1_bytes = [ + 223, 155, 91, 253, 92, 116, 16, 228, 169, 220, 252, 34, 61, 87, 155, 157, 60, 96, 94, 132, + 199, 11, 87, 64, 80, 75, 251, 183, 190, 249, 50, 35, 104, 10, 82, 173, 246, 8, 80, 230, + 221, 119, 131, 247, 72, 216, 153, 18, + ]; + let x0_c1_b1_a0_bytes = [ + 250, 77, 130, 197, 255, 70, 2, 248, 42, 12, 139, 237, 212, 143, 76, 125, 58, 221, 126, 44, + 217, 108, 8, 44, 150, 215, 153, 92, 49, 204, 179, 33, 8, 83, 253, 253, 229, 92, 72, 29, + 153, 131, 175, 39, 242, 89, 235, 12, + ]; + let x0_c1_b1_a1_bytes = [ + 96, 18, 99, 160, 37, 232, 100, 97, 94, 236, 38, 1, 124, 12, 127, 200, 142, 187, 92, 198, + 147, 114, 204, 177, 246, 34, 120, 66, 174, 224, 9, 250, 150, 182, 72, 229, 183, 57, 65, + 247, 239, 206, 37, 238, 217, 89, 113, 25, + ]; + let x0_c1_b2_a0_bytes = [ + 86, 113, 59, 186, 59, 194, 185, 19, 155, 48, 222, 99, 52, 213, 161, 32, 61, 208, 232, 126, + 193, 112, 193, 226, 67, 195, 78, 127, 121, 178, 125, 13, 230, 244, 75, 177, 128, 121, 245, + 106, 83, 157, 242, 30, 200, 116, 51, 10, + ]; + let x0_c1_b2_a1_bytes = [ + 205, 30, 202, 83, 93, 70, 131, 165, 76, 200, 101, 80, 49, 88, 147, 27, 104, 214, 227, 187, + 205, 246, 9, 210, 191, 12, 61, 187, 179, 172, 253, 254, 225, 192, 102, 190, 69, 17, 48, + 139, 88, 29, 190, 237, 160, 59, 213, 14, + ]; + let x2_c0_b0_a0_bytes = [ + 71, 158, 226, 94, 15, 60, 102, 52, 213, 157, 153, 47, 92, 130, 187, 97, 53, 22, 93, 208, + 27, 134, 165, 158, 166, 222, 70, 179, 83, 210, 55, 113, 161, 158, 96, 191, 132, 115, 16, + 164, 235, 215, 203, 8, 202, 111, 164, 3, + ]; + let x2_c0_b0_a1_bytes = [ + 179, 17, 26, 7, 85, 29, 212, 237, 20, 225, 222, 113, 225, 254, 24, 89, 220, 91, 66, 47, + 152, 193, 2, 54, 108, 109, 51, 87, 211, 82, 62, 172, 127, 106, 122, 174, 245, 147, 92, 70, + 38, 144, 48, 137, 23, 23, 117, 22, + ]; + let x2_c0_b1_a0_bytes = [ + 149, 111, 12, 131, 79, 201, 24, 186, 92, 70, 254, 36, 2, 125, 222, 214, 235, 139, 219, 116, + 105, 235, 108, 63, 81, 142, 61, 218, 32, 17, 138, 25, 183, 233, 98, 216, 36, 229, 68, 9, + 135, 245, 251, 153, 91, 52, 129, 20, + ]; + let x2_c0_b1_a1_bytes = [ + 51, 116, 227, 199, 197, 224, 41, 11, 194, 139, 151, 58, 114, 28, 52, 215, 47, 181, 200, 32, + 127, 140, 72, 184, 187, 135, 229, 18, 183, 11, 182, 22, 17, 9, 249, 145, 114, 57, 88, 239, + 131, 231, 65, 6, 155, 194, 254, 4, + ]; + let x2_c0_b2_a0_bytes = [ + 83, 243, 249, 17, 182, 3, 187, 178, 50, 163, 228, 7, 41, 42, 112, 214, 49, 230, 209, 51, + 47, 231, 202, 159, 207, 53, 206, 156, 185, 78, 41, 218, 53, 51, 150, 34, 225, 3, 70, 109, + 175, 0, 196, 203, 223, 250, 72, 23, + ]; + let x2_c0_b2_a1_bytes = [ + 199, 85, 149, 220, 117, 49, 210, 187, 65, 211, 178, 200, 40, 185, 196, 145, 71, 82, 217, + 89, 71, 169, 165, 111, 197, 116, 69, 251, 23, 153, 16, 20, 132, 175, 11, 145, 80, 126, 91, + 134, 75, 241, 10, 98, 180, 25, 75, 8, + ]; + let x2_c1_b0_a0_bytes = [ + 141, 236, 203, 10, 202, 77, 75, 56, 220, 209, 236, 228, 179, 193, 0, 11, 150, 176, 93, 11, + 160, 247, 196, 42, 124, 7, 17, 177, 63, 114, 152, 248, 70, 54, 208, 219, 105, 251, 220, + 155, 234, 26, 196, 108, 114, 133, 30, 15, + ]; + let x2_c1_b0_a1_bytes = [ + 11, 162, 153, 121, 1, 98, 69, 183, 236, 40, 118, 117, 84, 196, 122, 53, 214, 13, 246, 56, + 145, 63, 41, 189, 87, 227, 228, 123, 101, 181, 65, 245, 22, 17, 225, 34, 231, 239, 23, 138, + 67, 198, 49, 45, 16, 0, 34, 23, + ]; + let x2_c1_b1_a0_bytes = [ + 121, 71, 222, 182, 82, 106, 82, 68, 121, 64, 189, 104, 112, 119, 219, 131, 92, 81, 73, 12, + 67, 128, 130, 243, 98, 74, 171, 126, 252, 134, 58, 25, 252, 128, 244, 180, 125, 86, 217, + 76, 33, 252, 223, 237, 162, 185, 29, 10, + ]; + let x2_c1_b1_a1_bytes = [ + 21, 78, 120, 102, 240, 68, 106, 103, 189, 140, 232, 139, 109, 41, 214, 59, 7, 121, 26, 66, + 90, 102, 211, 18, 8, 42, 206, 212, 111, 72, 40, 112, 249, 144, 164, 128, 3, 165, 48, 132, + 127, 2, 45, 247, 63, 106, 89, 23, + ]; + let x2_c1_b2_a0_bytes = [ + 139, 58, 122, 68, 234, 250, 127, 30, 253, 71, 195, 108, 110, 86, 70, 100, 190, 112, 72, + 165, 128, 16, 212, 8, 59, 173, 66, 56, 168, 153, 20, 11, 212, 98, 254, 27, 216, 204, 202, + 169, 121, 168, 120, 226, 241, 209, 132, 0, + ]; + let x2_c1_b2_a1_bytes = [ + 92, 47, 142, 103, 182, 205, 41, 171, 63, 77, 46, 155, 28, 56, 96, 68, 63, 159, 183, 28, 81, + 184, 252, 185, 76, 140, 102, 186, 64, 129, 216, 87, 92, 34, 160, 50, 82, 54, 246, 65, 232, + 141, 147, 83, 83, 221, 127, 8, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12ConjugateCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12ConjugateCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.conjugate(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_conjugate() { + compile_generic(&E12ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12ConjugateCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 71, 177, 236, 83, 1, 17, 168, 246, 122, 100, 204, 112, 142, 217, 145, 27, 117, 59, 181, 4, + 229, 102, 112, 231, 144, 76, 212, 114, 160, 6, 240, 191, 127, 58, 84, 179, 120, 206, 111, + 94, 23, 146, 65, 115, 219, 104, 57, 7, + ]; + let x0_c0_b0_a1_bytes = [ + 98, 70, 164, 16, 248, 85, 63, 169, 213, 122, 167, 96, 191, 181, 158, 165, 5, 21, 59, 136, + 220, 102, 102, 91, 95, 82, 173, 119, 180, 92, 56, 130, 87, 92, 12, 105, 103, 69, 103, 145, + 223, 44, 36, 110, 162, 13, 254, 20, + ]; + let x0_c0_b1_a0_bytes = [ + 55, 212, 190, 91, 232, 203, 217, 72, 223, 44, 237, 68, 48, 180, 74, 228, 203, 178, 114, 41, + 178, 72, 186, 81, 112, 129, 254, 48, 20, 251, 238, 215, 62, 167, 155, 163, 75, 120, 212, + 115, 165, 23, 78, 10, 208, 29, 139, 18, + ]; + let x0_c0_b1_a1_bytes = [ + 55, 125, 236, 216, 16, 213, 81, 181, 69, 164, 134, 74, 142, 76, 172, 244, 196, 237, 117, + 33, 136, 47, 144, 228, 78, 210, 94, 247, 212, 110, 220, 35, 28, 248, 106, 140, 240, 37, + 195, 76, 191, 46, 212, 227, 44, 75, 38, 5, + ]; + let x0_c0_b2_a0_bytes = [ + 108, 135, 79, 73, 222, 246, 223, 3, 196, 88, 96, 97, 246, 150, 37, 39, 189, 31, 83, 226, + 241, 117, 168, 182, 37, 40, 84, 61, 167, 84, 169, 98, 124, 99, 203, 2, 251, 90, 140, 51, + 191, 75, 138, 35, 75, 61, 10, 14, + ]; + let x0_c0_b2_a1_bytes = [ + 115, 78, 45, 63, 204, 181, 103, 170, 128, 112, 113, 13, 17, 129, 119, 33, 165, 247, 110, + 180, 201, 227, 216, 210, 130, 153, 40, 247, 200, 149, 181, 183, 5, 175, 222, 84, 66, 50, + 224, 230, 163, 8, 219, 29, 88, 60, 117, 0, + ]; + let x0_c1_b0_a0_bytes = [ + 135, 152, 139, 17, 161, 56, 2, 200, 103, 224, 8, 28, 89, 75, 246, 96, 113, 142, 12, 114, + 129, 93, 114, 50, 98, 235, 194, 5, 255, 19, 176, 190, 238, 241, 217, 155, 94, 110, 35, 223, + 208, 121, 202, 45, 36, 228, 191, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 8, 121, 249, 42, 246, 219, 209, 219, 213, 193, 113, 42, 45, 186, 174, 204, 186, 34, 69, 23, + 107, 222, 217, 183, 104, 71, 116, 4, 83, 36, 127, 115, 127, 155, 99, 79, 112, 138, 154, 70, + 182, 27, 104, 18, 58, 153, 133, 25, + ]; + let x0_c1_b1_a0_bytes = [ + 243, 206, 55, 0, 101, 194, 150, 200, 220, 120, 221, 22, 96, 108, 9, 91, 132, 137, 197, 247, + 86, 186, 43, 155, 181, 94, 160, 171, 96, 172, 158, 111, 54, 155, 88, 2, 238, 135, 35, 144, + 225, 43, 226, 46, 73, 116, 171, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 75, 168, 150, 127, 101, 168, 30, 3, 55, 176, 63, 180, 55, 209, 78, 27, 13, 168, 137, 105, + 232, 78, 11, 32, 12, 151, 79, 87, 139, 175, 210, 4, 145, 22, 56, 237, 46, 14, 117, 113, + 229, 26, 58, 118, 133, 43, 13, 13, + ]; + let x0_c1_b2_a0_bytes = [ + 156, 21, 251, 228, 85, 140, 169, 144, 214, 200, 194, 238, 194, 169, 249, 223, 17, 86, 36, + 172, 183, 194, 241, 22, 28, 130, 174, 104, 241, 241, 85, 132, 33, 109, 84, 66, 149, 250, + 181, 179, 232, 160, 93, 201, 167, 65, 56, 4, + ]; + let x0_c1_b2_a1_bytes = [ + 45, 60, 150, 78, 181, 165, 56, 10, 10, 5, 96, 212, 194, 255, 149, 172, 157, 182, 107, 249, + 69, 53, 116, 209, 34, 203, 97, 54, 255, 246, 100, 104, 52, 72, 19, 171, 150, 61, 243, 104, + 213, 203, 37, 137, 119, 252, 231, 12, + ]; + let x2_c0_b0_a0_bytes = [ + 71, 177, 236, 83, 1, 17, 168, 246, 122, 100, 204, 112, 142, 217, 145, 27, 117, 59, 181, 4, + 229, 102, 112, 231, 144, 76, 212, 114, 160, 6, 240, 191, 127, 58, 84, 179, 120, 206, 111, + 94, 23, 146, 65, 115, 219, 104, 57, 7, + ]; + let x2_c0_b0_a1_bytes = [ + 98, 70, 164, 16, 248, 85, 63, 169, 213, 122, 167, 96, 191, 181, 158, 165, 5, 21, 59, 136, + 220, 102, 102, 91, 95, 82, 173, 119, 180, 92, 56, 130, 87, 92, 12, 105, 103, 69, 103, 145, + 223, 44, 36, 110, 162, 13, 254, 20, + ]; + let x2_c0_b1_a0_bytes = [ + 55, 212, 190, 91, 232, 203, 217, 72, 223, 44, 237, 68, 48, 180, 74, 228, 203, 178, 114, 41, + 178, 72, 186, 81, 112, 129, 254, 48, 20, 251, 238, 215, 62, 167, 155, 163, 75, 120, 212, + 115, 165, 23, 78, 10, 208, 29, 139, 18, + ]; + let x2_c0_b1_a1_bytes = [ + 55, 125, 236, 216, 16, 213, 81, 181, 69, 164, 134, 74, 142, 76, 172, 244, 196, 237, 117, + 33, 136, 47, 144, 228, 78, 210, 94, 247, 212, 110, 220, 35, 28, 248, 106, 140, 240, 37, + 195, 76, 191, 46, 212, 227, 44, 75, 38, 5, + ]; + let x2_c0_b2_a0_bytes = [ + 108, 135, 79, 73, 222, 246, 223, 3, 196, 88, 96, 97, 246, 150, 37, 39, 189, 31, 83, 226, + 241, 117, 168, 182, 37, 40, 84, 61, 167, 84, 169, 98, 124, 99, 203, 2, 251, 90, 140, 51, + 191, 75, 138, 35, 75, 61, 10, 14, + ]; + let x2_c0_b2_a1_bytes = [ + 115, 78, 45, 63, 204, 181, 103, 170, 128, 112, 113, 13, 17, 129, 119, 33, 165, 247, 110, + 180, 201, 227, 216, 210, 130, 153, 40, 247, 200, 149, 181, 183, 5, 175, 222, 84, 66, 50, + 224, 230, 163, 8, 219, 29, 88, 60, 117, 0, + ]; + let x2_c1_b0_a0_bytes = [ + 36, 18, 116, 238, 94, 199, 252, 241, 151, 31, 75, 149, 165, 180, 181, 189, 178, 103, 164, + 132, 31, 117, 190, 52, 93, 39, 194, 237, 133, 55, 199, 165, 232, 186, 113, 167, 87, 57, + 248, 107, 201, 108, 181, 11, 198, 45, 65, 24, + ]; + let x2_c1_b0_a1_bytes = [ + 163, 49, 6, 213, 9, 36, 45, 222, 41, 62, 226, 134, 209, 69, 253, 81, 105, 211, 107, 223, + 53, 244, 86, 175, 86, 203, 16, 239, 49, 39, 248, 240, 87, 17, 232, 243, 69, 29, 129, 4, + 228, 202, 23, 39, 176, 120, 123, 0, + ]; + let x2_c1_b1_a0_bytes = [ + 184, 219, 199, 255, 154, 61, 104, 241, 34, 135, 118, 154, 158, 147, 162, 195, 159, 108, + 235, 254, 73, 24, 5, 204, 9, 180, 228, 71, 36, 159, 216, 244, 160, 17, 243, 64, 200, 31, + 248, 186, 184, 186, 157, 10, 161, 157, 85, 14, + ]; + let x2_c1_b1_a1_bytes = [ + 96, 2, 105, 128, 154, 87, 224, 182, 200, 79, 20, 253, 198, 46, 93, 3, 23, 78, 39, 141, 184, + 131, 37, 71, 179, 123, 53, 156, 249, 155, 164, 95, 70, 150, 19, 86, 135, 153, 166, 217, + 180, 203, 69, 195, 100, 230, 243, 12, + ]; + let x2_c1_b2_a0_bytes = [ + 15, 149, 4, 27, 170, 115, 85, 41, 41, 55, 145, 194, 59, 86, 178, 62, 18, 160, 140, 74, 233, + 15, 63, 80, 163, 144, 214, 138, 147, 89, 33, 224, 181, 63, 247, 0, 33, 173, 101, 151, 177, + 69, 34, 112, 66, 208, 200, 21, + ]; + let x2_c1_b2_a1_bytes = [ + 126, 110, 105, 177, 74, 90, 198, 175, 245, 250, 243, 220, 59, 0, 22, 114, 134, 63, 69, 253, + 90, 157, 188, 149, 156, 71, 35, 189, 133, 84, 18, 252, 162, 100, 56, 152, 31, 106, 40, 226, + 196, 26, 90, 176, 114, 21, 25, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12ConjugateCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12InverseCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.inverse(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_inverse() { + compile_generic(&E12InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12InverseCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 239, 186, 91, 151, 236, 129, 147, 153, 101, 99, 53, 151, 162, 197, 14, 129, 206, 52, 82, + 66, 40, 93, 181, 127, 159, 109, 86, 10, 123, 147, 115, 119, 236, 230, 242, 84, 19, 56, 246, + 198, 89, 111, 151, 230, 140, 35, 172, 23, + ]; + let x0_c0_b0_a1_bytes = [ + 129, 197, 170, 169, 211, 1, 138, 50, 251, 182, 222, 65, 29, 85, 241, 112, 203, 123, 83, + 142, 78, 54, 101, 246, 241, 13, 107, 73, 73, 27, 215, 229, 113, 211, 109, 83, 250, 71, 151, + 173, 78, 35, 205, 118, 255, 190, 133, 2, + ]; + let x0_c0_b1_a0_bytes = [ + 213, 43, 183, 161, 86, 92, 215, 70, 136, 235, 36, 130, 5, 48, 66, 116, 93, 226, 131, 54, + 211, 42, 44, 129, 95, 197, 114, 157, 128, 111, 237, 159, 42, 235, 82, 225, 113, 134, 63, + 128, 68, 138, 243, 118, 58, 154, 85, 23, + ]; + let x0_c0_b1_a1_bytes = [ + 114, 226, 223, 71, 191, 8, 71, 98, 212, 201, 134, 67, 17, 67, 112, 72, 13, 33, 13, 224, 6, + 172, 231, 177, 160, 227, 217, 230, 147, 22, 70, 71, 125, 239, 212, 160, 161, 245, 34, 195, + 37, 117, 140, 115, 217, 166, 1, 12, + ]; + let x0_c0_b2_a0_bytes = [ + 173, 62, 209, 5, 189, 147, 109, 62, 65, 158, 66, 54, 136, 251, 249, 122, 50, 122, 70, 119, + 226, 158, 12, 244, 61, 175, 69, 95, 78, 101, 28, 103, 42, 21, 43, 254, 0, 183, 162, 17, + 202, 212, 97, 232, 169, 231, 31, 6, + ]; + let x0_c0_b2_a1_bytes = [ + 102, 4, 179, 120, 17, 221, 42, 212, 239, 7, 7, 31, 186, 185, 3, 44, 237, 22, 250, 85, 111, + 94, 226, 138, 111, 134, 175, 237, 55, 208, 37, 210, 231, 8, 254, 247, 196, 61, 138, 81, + 208, 158, 27, 122, 37, 166, 58, 14, + ]; + let x0_c1_b0_a0_bytes = [ + 68, 117, 204, 86, 188, 131, 76, 39, 232, 170, 1, 168, 214, 0, 211, 16, 139, 169, 39, 58, + 251, 138, 210, 214, 10, 95, 209, 138, 91, 65, 161, 116, 191, 111, 56, 130, 80, 38, 168, + 232, 117, 1, 73, 115, 124, 171, 43, 11, + ]; + let x0_c1_b0_a1_bytes = [ + 7, 122, 155, 89, 246, 186, 116, 55, 46, 146, 121, 114, 185, 240, 212, 116, 96, 14, 145, + 133, 36, 128, 156, 208, 153, 122, 95, 170, 97, 83, 156, 180, 196, 193, 166, 73, 128, 146, + 146, 20, 250, 6, 91, 179, 83, 233, 79, 17, + ]; + let x0_c1_b1_a0_bytes = [ + 54, 148, 249, 115, 176, 147, 190, 102, 19, 199, 129, 72, 19, 255, 35, 66, 35, 39, 139, 124, + 233, 5, 56, 74, 211, 196, 116, 80, 177, 184, 65, 142, 219, 129, 2, 214, 251, 11, 61, 231, + 142, 103, 194, 34, 114, 204, 241, 18, + ]; + let x0_c1_b1_a1_bytes = [ + 149, 115, 220, 144, 24, 182, 223, 191, 4, 238, 199, 71, 115, 98, 97, 148, 102, 62, 143, 18, + 71, 27, 64, 213, 180, 149, 53, 153, 46, 192, 74, 169, 109, 199, 19, 27, 247, 92, 194, 209, + 115, 88, 36, 43, 23, 235, 99, 3, + ]; + let x0_c1_b2_a0_bytes = [ + 207, 64, 86, 239, 93, 197, 185, 192, 250, 176, 52, 113, 5, 9, 141, 195, 16, 43, 42, 138, + 200, 149, 95, 121, 15, 125, 71, 119, 141, 68, 215, 140, 2, 220, 57, 6, 73, 21, 185, 32, + 111, 5, 235, 41, 136, 124, 143, 10, + ]; + let x0_c1_b2_a1_bytes = [ + 163, 180, 236, 225, 210, 55, 0, 151, 126, 111, 86, 98, 207, 29, 45, 229, 123, 119, 174, + 140, 120, 117, 78, 237, 155, 193, 218, 54, 191, 241, 33, 5, 145, 169, 207, 165, 84, 25, 99, + 106, 93, 124, 150, 93, 43, 46, 25, 2, + ]; + let x2_c0_b0_a0_bytes = [ + 57, 214, 182, 130, 35, 159, 250, 24, 209, 249, 80, 73, 243, 134, 169, 163, 114, 248, 153, + 112, 127, 226, 230, 68, 197, 234, 100, 109, 111, 98, 238, 0, 214, 165, 110, 228, 34, 255, + 243, 76, 107, 48, 226, 17, 93, 223, 138, 7, + ]; + let x2_c0_b0_a1_bytes = [ + 161, 146, 144, 233, 77, 212, 55, 2, 104, 132, 98, 221, 178, 21, 102, 5, 108, 47, 242, 77, + 97, 196, 63, 16, 232, 62, 255, 69, 229, 213, 80, 32, 191, 163, 15, 40, 94, 56, 112, 207, + 110, 239, 148, 161, 222, 178, 210, 24, + ]; + let x2_c0_b1_a0_bytes = [ + 89, 67, 10, 79, 236, 37, 119, 218, 66, 177, 21, 220, 69, 153, 231, 145, 242, 6, 110, 247, + 155, 53, 163, 68, 134, 161, 21, 182, 60, 156, 127, 205, 125, 126, 113, 112, 7, 44, 193, + 129, 104, 203, 241, 240, 114, 100, 189, 18, + ]; + let x2_c0_b1_a1_bytes = [ + 86, 135, 71, 239, 167, 1, 39, 92, 175, 78, 24, 72, 242, 186, 239, 252, 243, 182, 155, 181, + 254, 11, 202, 187, 134, 137, 139, 112, 249, 252, 164, 178, 32, 149, 88, 48, 171, 167, 198, + 56, 242, 47, 161, 83, 184, 99, 20, 13, + ]; + let x2_c0_b2_a0_bytes = [ + 119, 10, 21, 35, 53, 171, 73, 201, 190, 67, 49, 86, 58, 77, 247, 76, 80, 240, 12, 59, 8, + 89, 147, 164, 147, 54, 211, 62, 114, 137, 64, 39, 186, 240, 252, 134, 109, 255, 125, 101, + 97, 89, 71, 44, 115, 120, 233, 24, + ]; + let x2_c0_b2_a1_bytes = [ + 182, 232, 20, 90, 71, 192, 139, 141, 111, 157, 143, 24, 204, 150, 173, 203, 139, 134, 130, + 160, 171, 135, 20, 204, 236, 150, 25, 223, 43, 37, 145, 212, 102, 207, 204, 32, 78, 142, + 23, 44, 79, 8, 42, 199, 176, 105, 208, 8, + ]; + let x2_c1_b0_a0_bytes = [ + 177, 251, 61, 25, 122, 5, 17, 207, 251, 43, 55, 10, 247, 253, 31, 163, 175, 201, 61, 254, + 47, 144, 137, 204, 83, 57, 178, 171, 255, 69, 153, 178, 165, 217, 113, 28, 235, 33, 203, 6, + 207, 251, 85, 32, 219, 4, 161, 15, + ]; + let x2_c1_b0_a1_bytes = [ + 224, 185, 252, 67, 17, 11, 212, 145, 15, 21, 53, 184, 30, 147, 28, 140, 61, 193, 213, 87, + 132, 221, 11, 125, 69, 105, 73, 204, 152, 156, 134, 106, 210, 73, 189, 209, 109, 164, 161, + 232, 241, 171, 183, 123, 243, 240, 69, 17, + ]; + let x2_c1_b1_a0_bytes = [ + 210, 222, 123, 144, 12, 44, 162, 17, 183, 202, 81, 141, 237, 186, 74, 145, 60, 11, 235, + 203, 217, 207, 77, 119, 54, 162, 37, 122, 37, 125, 203, 106, 192, 193, 198, 216, 102, 173, + 152, 126, 29, 217, 26, 101, 71, 28, 71, 12, + ]; + let x2_c1_b1_a1_bytes = [ + 127, 8, 161, 3, 209, 235, 42, 144, 140, 233, 109, 196, 17, 15, 62, 139, 56, 181, 19, 120, + 176, 247, 44, 34, 155, 222, 189, 228, 93, 70, 24, 167, 83, 250, 171, 150, 195, 194, 212, + 136, 247, 103, 205, 104, 87, 227, 41, 10, + ]; + let x2_c1_b2_a0_bytes = [ + 149, 146, 83, 190, 252, 159, 164, 10, 252, 95, 197, 72, 197, 222, 22, 150, 236, 47, 242, + 28, 19, 182, 100, 118, 242, 41, 87, 156, 192, 146, 219, 11, 150, 114, 7, 140, 84, 132, 57, + 98, 151, 187, 49, 172, 0, 154, 158, 4, + ]; + let x2_c1_b2_a1_bytes = [ + 89, 197, 129, 4, 34, 216, 120, 179, 250, 172, 172, 26, 57, 188, 253, 93, 68, 213, 85, 156, + 232, 216, 158, 222, 243, 177, 243, 162, 177, 230, 118, 217, 138, 9, 135, 45, 160, 84, 233, + 110, 67, 47, 104, 250, 232, 222, 121, 0, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12InverseCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12MulBy014Circuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + w: [[[[Variable; 48]; 2]; 3]; 2], + b: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E12MulBy014Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let w_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.w[0][0][0].to_vec(), 0), + a1: new_internal_element(self.w[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.w[0][1][0].to_vec(), 0), + a1: new_internal_element(self.w[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.w[0][2][0].to_vec(), 0), + a1: new_internal_element(self.w[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.w[1][0][0].to_vec(), 0), + a1: new_internal_element(self.w[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.w[1][1][0].to_vec(), 0), + a1: new_internal_element(self.w[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.w[1][2][0].to_vec(), 0), + a1: new_internal_element(self.w[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e2 = GE2 { + a0: new_internal_element(self.b[0].to_vec(), 0), + a1: new_internal_element(self.b[1].to_vec(), 0), + }; + + let c_e2 = GE2 { + a0: new_internal_element(self.c[0].to_vec(), 0), + a1: new_internal_element(self.c[1].to_vec(), 0), + }; + + let z = ext12.mul_by_014(builder, &a_e12, &b_e2, &c_e2); + ext12.assert_isequal(builder, &z, &w_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_mul_by_014() { + // let compile_result = + // compile_generic(&E12MulBy014Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12MulBy014Circuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + w: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[M31::from(0); 48]; 2], + c: [[M31::from(0); 48]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 46, 225, 141, 72, 79, 6, 52, 59, 209, 213, 86, 160, 220, 208, 132, 110, 53, 70, 111, 237, + 250, 13, 135, 108, 93, 27, 196, 125, 229, 194, 108, 221, 127, 4, 115, 130, 225, 243, 250, + 188, 89, 102, 164, 141, 191, 208, 246, 22, + ]; + let x0_c0_b0_a1_bytes = [ + 31, 107, 172, 201, 84, 5, 66, 186, 151, 71, 249, 145, 228, 59, 45, 212, 200, 223, 1, 16, + 229, 57, 250, 233, 212, 35, 187, 34, 118, 226, 250, 125, 125, 173, 6, 187, 2, 234, 253, + 112, 193, 250, 181, 214, 49, 29, 150, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 102, 132, 113, 1, 157, 235, 122, 46, 89, 173, 53, 254, 78, 47, 128, 55, 205, 137, 5, 222, + 247, 82, 1, 250, 59, 129, 8, 180, 128, 183, 28, 9, 111, 191, 183, 115, 239, 27, 222, 239, + 238, 61, 74, 8, 57, 100, 87, 14, + ]; + let x0_c0_b1_a1_bytes = [ + 211, 198, 117, 79, 222, 237, 57, 94, 161, 82, 233, 228, 137, 153, 45, 193, 238, 255, 73, + 106, 208, 95, 16, 191, 145, 216, 253, 216, 63, 176, 145, 77, 179, 252, 234, 60, 4, 184, 71, + 22, 19, 70, 176, 90, 243, 27, 190, 13, + ]; + let x0_c0_b2_a0_bytes = [ + 151, 168, 135, 95, 89, 100, 143, 171, 239, 191, 150, 12, 80, 189, 237, 24, 22, 155, 221, + 154, 95, 234, 83, 226, 158, 222, 54, 60, 182, 225, 240, 29, 122, 81, 228, 72, 240, 76, 243, + 94, 198, 255, 8, 19, 222, 224, 137, 21, + ]; + let x0_c0_b2_a1_bytes = [ + 28, 112, 79, 97, 105, 30, 99, 190, 237, 253, 96, 11, 23, 52, 152, 45, 155, 53, 10, 47, 6, + 39, 119, 166, 156, 107, 163, 207, 226, 140, 64, 65, 96, 200, 95, 201, 13, 55, 127, 136, 55, + 9, 123, 33, 67, 0, 158, 21, + ]; + let x0_c1_b0_a0_bytes = [ + 212, 171, 88, 128, 53, 43, 171, 112, 143, 58, 210, 187, 196, 137, 38, 89, 57, 223, 27, 124, + 231, 24, 0, 187, 204, 189, 55, 104, 249, 111, 68, 82, 11, 127, 112, 65, 163, 142, 48, 175, + 61, 165, 140, 94, 7, 93, 134, 23, + ]; + let x0_c1_b0_a1_bytes = [ + 69, 70, 146, 4, 112, 110, 61, 229, 87, 7, 88, 244, 130, 214, 149, 194, 13, 228, 203, 135, + 25, 62, 35, 215, 158, 227, 144, 239, 67, 100, 10, 250, 22, 57, 183, 186, 56, 197, 235, 11, + 44, 103, 198, 44, 169, 66, 41, 6, + ]; + let x0_c1_b1_a0_bytes = [ + 116, 164, 31, 12, 98, 150, 12, 73, 229, 235, 76, 171, 164, 90, 119, 217, 95, 2, 213, 201, + 107, 68, 44, 233, 66, 236, 251, 36, 209, 84, 101, 16, 39, 100, 113, 12, 173, 46, 113, 75, + 99, 150, 80, 82, 216, 89, 173, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 212, 231, 52, 254, 7, 77, 81, 168, 142, 65, 198, 223, 119, 200, 170, 39, 62, 180, 161, 52, + 229, 96, 188, 148, 59, 205, 34, 160, 235, 54, 180, 242, 166, 165, 80, 213, 187, 178, 112, + 41, 236, 98, 135, 190, 50, 87, 148, 17, + ]; + let x0_c1_b2_a0_bytes = [ + 203, 2, 160, 135, 190, 99, 216, 217, 114, 53, 245, 58, 73, 240, 132, 99, 109, 175, 162, + 114, 96, 150, 248, 105, 216, 12, 205, 67, 121, 31, 105, 68, 189, 49, 20, 110, 8, 108, 146, + 5, 248, 7, 36, 205, 153, 144, 33, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 203, 136, 84, 84, 75, 168, 160, 42, 254, 245, 246, 224, 74, 54, 92, 224, 184, 237, 123, 60, + 155, 213, 237, 99, 78, 84, 82, 187, 38, 238, 213, 213, 150, 148, 186, 89, 137, 174, 204, + 235, 236, 253, 12, 2, 84, 47, 121, 10, + ]; + let x1_a0_bytes = [ + 97, 217, 42, 113, 196, 20, 178, 27, 215, 13, 156, 167, 138, 17, 171, 196, 232, 155, 154, + 149, 209, 178, 84, 234, 115, 240, 69, 32, 234, 186, 21, 219, 82, 254, 108, 18, 101, 227, + 82, 125, 231, 36, 240, 88, 221, 86, 203, 4, + ]; + let x1_a1_bytes = [ + 181, 119, 85, 130, 130, 97, 98, 37, 183, 64, 108, 80, 157, 44, 213, 158, 31, 115, 18, 140, + 43, 129, 7, 96, 201, 228, 58, 17, 72, 80, 38, 60, 222, 6, 243, 230, 151, 157, 15, 199, 64, + 204, 251, 199, 87, 114, 30, 14, + ]; + let x2_a0_bytes = [ + 142, 57, 191, 139, 145, 59, 244, 144, 145, 73, 235, 127, 111, 15, 212, 26, 156, 71, 198, + 192, 110, 63, 33, 64, 132, 28, 22, 180, 142, 188, 167, 105, 90, 169, 73, 42, 100, 218, 78, + 81, 162, 17, 252, 88, 132, 34, 36, 25, + ]; + let x2_a1_bytes = [ + 141, 172, 175, 31, 128, 169, 179, 227, 202, 136, 6, 176, 193, 155, 72, 63, 72, 69, 49, 75, + 204, 13, 77, 41, 90, 208, 48, 109, 251, 81, 88, 232, 104, 211, 141, 6, 146, 48, 156, 255, + 102, 143, 17, 169, 187, 25, 164, 24, + ]; + let x3_c0_b0_a0_bytes = [ + 139, 193, 89, 3, 233, 201, 122, 223, 194, 169, 54, 194, 48, 252, 80, 208, 78, 220, 230, 21, + 0, 245, 152, 35, 53, 51, 57, 175, 145, 231, 17, 100, 230, 199, 48, 3, 91, 7, 51, 3, 201, + 191, 182, 179, 127, 245, 84, 22, + ]; + let x3_c0_b0_a1_bytes = [ + 143, 137, 64, 149, 139, 89, 220, 39, 12, 127, 45, 136, 61, 41, 159, 67, 114, 127, 252, 46, + 20, 121, 136, 49, 88, 130, 161, 80, 103, 23, 73, 179, 59, 221, 18, 162, 143, 167, 85, 43, + 54, 92, 223, 169, 48, 23, 33, 13, + ]; + let x3_c0_b1_a0_bytes = [ + 218, 58, 2, 251, 106, 226, 165, 205, 132, 234, 252, 159, 96, 3, 66, 52, 135, 235, 35, 245, + 178, 53, 125, 139, 37, 161, 93, 201, 234, 166, 231, 137, 2, 46, 84, 203, 210, 63, 135, 22, + 39, 121, 217, 49, 195, 178, 109, 13, + ]; + let x3_c0_b1_a1_bytes = [ + 69, 81, 11, 211, 140, 63, 176, 144, 200, 183, 213, 228, 47, 4, 188, 80, 145, 7, 70, 41, + 127, 13, 90, 22, 44, 221, 197, 66, 237, 119, 132, 158, 164, 38, 247, 160, 217, 173, 103, 2, + 227, 124, 246, 225, 247, 237, 70, 8, + ]; + let x3_c0_b2_a0_bytes = [ + 213, 70, 9, 166, 158, 52, 110, 129, 50, 212, 141, 195, 222, 84, 123, 45, 199, 68, 201, 227, + 209, 120, 57, 73, 231, 101, 30, 138, 183, 8, 48, 53, 71, 37, 251, 64, 241, 72, 16, 136, + 174, 60, 196, 26, 204, 252, 254, 16, + ]; + let x3_c0_b2_a1_bytes = [ + 92, 75, 160, 53, 232, 125, 245, 45, 81, 16, 110, 36, 179, 125, 207, 188, 190, 45, 100, 167, + 24, 74, 103, 225, 158, 87, 184, 194, 198, 69, 15, 77, 142, 228, 157, 196, 111, 103, 84, + 244, 167, 53, 118, 185, 177, 119, 212, 23, + ]; + let x3_c1_b0_a0_bytes = [ + 79, 180, 128, 190, 186, 98, 168, 175, 124, 93, 72, 97, 41, 254, 186, 145, 181, 2, 3, 99, + 19, 243, 187, 225, 99, 96, 108, 143, 214, 4, 119, 79, 171, 52, 55, 3, 240, 237, 207, 179, + 186, 129, 67, 225, 190, 53, 232, 5, + ]; + let x3_c1_b0_a1_bytes = [ + 101, 50, 45, 138, 153, 115, 140, 5, 53, 2, 165, 107, 108, 181, 19, 195, 66, 84, 132, 120, + 144, 67, 247, 39, 47, 0, 32, 226, 132, 40, 109, 58, 69, 196, 160, 249, 51, 240, 102, 156, + 13, 85, 69, 252, 91, 12, 10, 0, + ]; + let x3_c1_b1_a0_bytes = [ + 148, 187, 155, 201, 27, 246, 72, 5, 110, 230, 145, 147, 78, 48, 217, 232, 208, 216, 193, + 55, 149, 123, 211, 76, 177, 184, 136, 97, 171, 210, 173, 128, 212, 119, 192, 0, 128, 8, + 157, 49, 248, 39, 179, 185, 226, 163, 81, 18, + ]; + let x3_c1_b1_a1_bytes = [ + 1, 157, 251, 4, 189, 95, 113, 234, 155, 50, 0, 251, 38, 171, 221, 139, 75, 188, 130, 49, + 177, 148, 232, 100, 251, 64, 90, 167, 177, 187, 140, 234, 43, 133, 148, 174, 104, 4, 12, + 65, 237, 37, 45, 125, 68, 64, 239, 6, + ]; + let x3_c1_b2_a0_bytes = [ + 199, 44, 149, 165, 101, 136, 132, 147, 162, 147, 239, 173, 253, 64, 189, 26, 139, 51, 208, + 95, 216, 1, 193, 161, 199, 211, 25, 240, 43, 126, 189, 172, 166, 101, 10, 165, 218, 25, + 170, 24, 167, 87, 240, 13, 45, 62, 111, 23, + ]; + let x3_c1_b2_a1_bytes = [ + 205, 79, 236, 205, 166, 11, 179, 69, 160, 45, 40, 178, 191, 234, 149, 228, 61, 98, 86, 83, + 162, 219, 49, 32, 134, 142, 185, 213, 255, 225, 114, 198, 88, 86, 22, 229, 93, 24, 197, + 179, 155, 224, 134, 14, 203, 213, 114, 8, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][i] = M31::from(x1_a0_bytes[i]); + assignment.b[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c[0][i] = M31::from(x2_a0_bytes[i]); + assignment.c[1][i] = M31::from(x2_a1_bytes[i]); + assignment.w[0][0][0][i] = M31::from(x3_c0_b0_a0_bytes[i]); + assignment.w[0][0][1][i] = M31::from(x3_c0_b0_a1_bytes[i]); + assignment.w[0][1][0][i] = M31::from(x3_c0_b1_a0_bytes[i]); + assignment.w[0][1][1][i] = M31::from(x3_c0_b1_a1_bytes[i]); + assignment.w[0][2][0][i] = M31::from(x3_c0_b2_a0_bytes[i]); + assignment.w[0][2][1][i] = M31::from(x3_c0_b2_a1_bytes[i]); + assignment.w[1][0][0][i] = M31::from(x3_c1_b0_a0_bytes[i]); + assignment.w[1][0][1][i] = M31::from(x3_c1_b0_a1_bytes[i]); + assignment.w[1][1][0][i] = M31::from(x3_c1_b1_a0_bytes[i]); + assignment.w[1][1][1][i] = M31::from(x3_c1_b1_a1_bytes[i]); + assignment.w[1][2][0][i] = M31::from(x3_c1_b2_a0_bytes[i]); + assignment.w[1][2][1][i] = M31::from(x3_c1_b2_a1_bytes[i]); + } + debug_eval(&E12MulBy014Circuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs new file mode 100644 index 00000000..a21653bf --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs @@ -0,0 +1,859 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::e2::{Ext2, GE2}, + hints::register_hint, +}; +use expander_compiler::frontend::compile_generic; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, +}; +declare_circuit!(E2AddCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.add(builder, &x_e2, &y_e2); + let expect_z = GE2 { + a0: new_internal_element(self.z[0].to_vec(), 0), + a1: new_internal_element(self.z[1].to_vec(), 0), + }; + ext2.assert_isequal(builder, &z, &expect_z); + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_add() { + compile_generic(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2AddCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 218, 253, 64, 116, 175, 52, 24, 151, 151, 215, 179, 170, 76, 250, 69, 90, 88, 37, 34, 244, + 208, 51, 26, 6, 74, 174, 1, 199, 44, 146, 237, 75, 240, 250, 248, 226, 161, 68, 67, 49, + 204, 164, 203, 228, 12, 79, 238, 5, + ]; + let z1_bytes = [ + 162, 191, 112, 190, 81, 47, 128, 118, 149, 112, 222, 152, 142, 11, 49, 60, 180, 34, 229, + 197, 248, 214, 150, 237, 125, 100, 177, 224, 222, 18, 165, 199, 250, 85, 240, 222, 198, 4, + 78, 217, 202, 6, 85, 164, 7, 27, 109, 21, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + // debug_eval( + // &E2AddCircuit::default(), + // &assignment, + // hint_registry, + // ); +} + +declare_circuit!(E2SubCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let mut z = ext2.sub(builder, &x_e2, &y_e2); + + for _ in 0..32 { + z = ext2.sub(builder, &z, &y_e2); + } + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_sub() { + // let compile_result = compile(&E2SubCircuit::default()).unwrap(); + compile_generic(&E2SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2SubCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 180, 154, 49, 237, 175, 103, 82, 20, 105, 240, 180, 74, 119, 170, 182, 138, 184, 18, 206, + 191, 32, 71, 9, 182, 8, 193, 77, 188, 13, 81, 201, 58, 230, 82, 112, 173, 148, 255, 140, + 242, 236, 80, 118, 157, 164, 163, 65, 2, + ]; + let z1_bytes = [ + 159, 131, 176, 227, 240, 63, 9, 101, 141, 81, 41, 242, 7, 124, 254, 196, 126, 132, 52, 92, + 223, 29, 85, 61, 146, 31, 145, 149, 254, 27, 211, 122, 228, 121, 59, 129, 208, 247, 31, + 103, 24, 11, 170, 61, 11, 131, 77, 8, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2SubCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2DoubleCircuit { + x: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2DoubleCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let z = ext2.double(builder, &x_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_double() { + // let compile_result = compile(&E2DoubleCircuit::default()).unwrap(); + compile_generic(&E2DoubleCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2DoubleCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 15, 12, 79, 128, 139, 180, 205, 255, 209, 222, 213, 222, 254, 248, 10, 230, 191, 105, 202, + 47, 136, 213, 107, 173, 156, 11, 113, 96, 198, 183, 126, 251, 141, 187, 41, 102, 110, 132, + 31, 81, 75, 249, 2, 47, 228, 206, 81, 3, + ]; + let x1_bytes = [ + 240, 227, 119, 201, 24, 76, 33, 152, 185, 85, 45, 193, 110, 41, 147, 127, 248, 176, 165, + 66, 82, 161, 225, 108, 180, 84, 20, 69, 127, 71, 121, 72, 69, 230, 93, 22, 77, 43, 82, 119, + 31, 115, 198, 136, 207, 8, 46, 2, + ]; + let z0_bytes = [ + 30, 24, 158, 0, 23, 105, 155, 255, 163, 189, 171, 189, 253, 241, 21, 204, 127, 211, 148, + 95, 16, 171, 215, 90, 57, 23, 226, 192, 140, 111, 253, 246, 27, 119, 83, 204, 220, 8, 63, + 162, 150, 242, 5, 94, 200, 157, 163, 6, + ]; + let z1_bytes = [ + 224, 199, 239, 146, 49, 152, 66, 48, 115, 171, 90, 130, 221, 82, 38, 255, 240, 97, 75, 133, + 164, 66, 195, 217, 104, 169, 40, 138, 254, 142, 242, 144, 138, 204, 187, 44, 154, 86, 164, + 238, 62, 230, 140, 17, 159, 17, 92, 4, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2DoubleCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2MulCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.mul(builder, &x_e2, &y_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul() { + // let compile_result = compile(&E2MulCircuit::default()).unwrap(); + compile_generic(&E2MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 143, 141, 88, 121, 8, 168, 107, 196, 223, 95, 145, 40, 180, 240, 14, 127, 2, 131, 208, 179, + 204, 73, 135, 148, 189, 111, 164, 105, 224, 184, 248, 44, 208, 132, 0, 64, 210, 236, 241, + 225, 171, 116, 246, 214, 71, 118, 162, 23, + ]; + let z1_bytes = [ + 45, 113, 243, 46, 31, 23, 35, 212, 99, 184, 76, 19, 176, 150, 92, 64, 237, 213, 204, 21, + 66, 195, 173, 145, 168, 82, 248, 96, 149, 128, 101, 6, 129, 187, 168, 243, 171, 181, 118, + 146, 105, 156, 106, 82, 54, 190, 245, 20, + ]; + + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2SquareCircuit { + x: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let z = ext2.square(builder, &x_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_square() { + // let compile_result = compile(&E2SquareCircuit::default()).unwrap(); + compile_generic(&E2SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2SquareCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 76, 190, 203, 175, 214, 65, 32, 217, 101, 144, 196, 235, 159, 76, 190, 209, 46, 223, 169, + 88, 25, 193, 105, 217, 115, 6, 68, 7, 79, 4, 154, 56, 167, 2, 202, 34, 126, 222, 83, 233, + 137, 224, 221, 96, 140, 156, 5, 18, + ]; + let z1_bytes = [ + 170, 117, 86, 12, 84, 70, 123, 39, 30, 83, 226, 114, 113, 237, 118, 58, 194, 47, 111, 221, + 135, 155, 127, 91, 79, 86, 4, 68, 107, 170, 254, 51, 102, 128, 53, 134, 93, 97, 103, 22, + 243, 175, 90, 255, 163, 111, 193, 25, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2DivCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.div(builder, &x_e2, &y_e2); + // let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + // let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z.a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z.a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_div() { + compile_generic(&E2DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2DivCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 153, 184, 22, 74, 13, 182, 120, 88, 173, 188, 79, 252, 223, 69, 219, 113, 24, 134, 224, + 254, 32, 98, 137, 82, 111, 109, 147, 178, 206, 57, 2, 59, 140, 168, 221, 75, 120, 184, 199, + 120, 106, 250, 243, 94, 234, 159, 235, 8, + ]; + let z1_bytes = [ + 177, 188, 16, 148, 100, 119, 79, 251, 253, 76, 250, 108, 166, 218, 213, 148, 139, 44, 125, + 158, 121, 112, 238, 245, 236, 191, 74, 85, 188, 152, 34, 142, 65, 72, 66, 245, 76, 125, 71, + 123, 203, 25, 122, 132, 192, 59, 181, 2, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2MulByElementCircuit { + a: [[Variable; 48]; 2], + b: [Variable; 48], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulByElementCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let b = new_internal_element(self.b.to_vec(), 0); + let c = ext2.mul_by_element(builder, &a_e2, &b); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul_by_element() { + // let compile_result = compile(&E2MulByElementCircuit::default()).unwrap(); + compile_generic(&E2MulByElementCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulByElementCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + b: [M31::from(0); 48], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let z0_bytes = [ + 182, 22, 7, 253, 0, 12, 198, 225, 34, 100, 90, 32, 63, 141, 75, 146, 131, 75, 234, 238, + 183, 203, 163, 40, 205, 44, 246, 38, 124, 126, 21, 66, 113, 12, 134, 89, 79, 157, 177, 199, + 10, 108, 231, 138, 198, 51, 108, 16, + ]; + let z1_bytes = [ + 99, 158, 220, 37, 153, 125, 46, 222, 184, 169, 143, 169, 208, 242, 197, 124, 114, 180, 20, + 50, 232, 149, 134, 129, 164, 99, 50, 252, 99, 116, 250, 173, 155, 113, 102, 35, 155, 201, + 251, 48, 142, 96, 192, 33, 247, 46, 83, 10, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.b[i] = M31::from(y0_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval( + &E2MulByElementCircuit::default(), + &assignment, + hint_registry, + ); +} + +declare_circuit!(E2MulByNonResidueCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulByNonResidueCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.mul_by_non_residue(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul_by_non_residue() { + compile_generic( + &E2MulByNonResidueCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulByNonResidueCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 24, 121, 23, 51, 235, 200, 233, 241, 235, 130, 176, 49, 143, 59, 247, 120, 90, 148, 249, + 119, 184, 1, 7, 4, 16, 22, 139, 43, 65, 233, 51, 184, 108, 249, 28, 99, 112, 183, 202, 90, + 189, 0, 3, 217, 1, 228, 197, 17, + ]; + let z1_bytes = [ + 154, 191, 115, 81, 54, 226, 255, 247, 146, 249, 244, 161, 121, 202, 102, 150, 111, 216, 62, + 150, 107, 86, 152, 164, 202, 87, 7, 121, 193, 47, 161, 128, 188, 167, 82, 85, 162, 162, + 120, 41, 57, 214, 150, 56, 87, 72, 255, 2, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval( + &E2MulByNonResidueCircuit::default(), + &assignment, + hint_registry, + ); +} + +declare_circuit!(E2NegCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2NegCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.neg(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_neg() { + // let compile_result = compile(&E2NegCircuit::default()).unwrap(); + compile_generic(&E2NegCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2NegCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 82, 14, 186, 61, 111, 42, 10, 69, 192, 65, 129, 71, 250, 252, 252, 22, 191, 191, 148, 239, + 142, 38, 225, 18, 210, 219, 59, 161, 3, 191, 12, 200, 66, 220, 19, 231, 172, 250, 249, 8, + 31, 251, 178, 176, 189, 123, 158, 15, + ]; + let z1_bytes = [ + 191, 220, 209, 112, 90, 243, 244, 124, 172, 196, 221, 199, 138, 56, 72, 113, 245, 93, 221, + 112, 166, 85, 183, 175, 34, 223, 65, 217, 191, 92, 201, 27, 216, 40, 229, 6, 103, 10, 169, + 24, 66, 21, 54, 80, 213, 77, 99, 7, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2NegCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2ConjugateCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2ConjugateCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.conjugate(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_conjugate() { + // let compile_result = compile(&E2ConjugateCircuit::default()).unwrap(); + compile_generic(&E2ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2ConjugateCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let z1_bytes = [ + 191, 220, 209, 112, 90, 243, 244, 124, 172, 196, 221, 199, 138, 56, 72, 113, 245, 93, 221, + 112, 166, 85, 183, 175, 34, 223, 65, 217, 191, 92, 201, 27, 216, 40, 229, 6, 103, 10, 169, + 24, 66, 21, 54, 80, 213, 77, 99, 7, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2ConjugateCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2InverseCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.inverse(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_inverse() { + compile_generic(&E2InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2InverseCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 188, 73, 170, 2, 86, 109, 56, 49, 4, 214, 214, 65, 170, 212, 146, 167, 82, 42, 230, 70, + 169, 141, 41, 214, 126, 246, 187, 34, 14, 112, 134, 20, 9, 143, 115, 7, 74, 103, 198, 27, + 169, 146, 135, 186, 148, 116, 195, 13, + ]; + let z1_bytes = [ + 25, 50, 4, 38, 189, 74, 213, 48, 113, 22, 13, 43, 46, 44, 21, 243, 221, 101, 44, 217, 100, + 12, 139, 227, 50, 156, 163, 74, 52, 27, 167, 130, 108, 55, 41, 186, 118, 30, 138, 246, 64, + 0, 64, 43, 180, 117, 173, 10, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2InverseCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs new file mode 100644 index 00000000..bc8db2d9 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs @@ -0,0 +1,1682 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e2::GE2, + e6::{Ext6, GE6}, + }, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(E6AddCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + let z = ext6.add(builder, &x_e6, &y_e6); + ext6.assert_isequal(builder, &z, &z_e6); + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} +#[test] +fn test_e6_add() { + // let compile_result = compile(&E2AddCircuit::default()).unwrap(); + compile_generic(&E6AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E6AddCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 43, 211, 155, 220, 85, 4, 8, 1, 215, 211, 93, 215, 81, 21, 56, 57, 139, 64, 114, 222, 34, + 249, 133, 1, 89, 193, 221, 30, 159, 24, 10, 156, 26, 94, 220, 176, 241, 186, 246, 191, 181, + 92, 117, 198, 20, 54, 44, 14, + ]; + let x0_b0_a1_bytes = [ + 63, 131, 211, 85, 212, 40, 216, 174, 142, 150, 21, 245, 183, 100, 255, 199, 209, 21, 209, + 87, 66, 192, 97, 175, 236, 116, 95, 238, 93, 20, 154, 35, 164, 253, 56, 202, 205, 64, 0, + 200, 179, 17, 69, 28, 185, 161, 70, 13, + ]; + let x0_b1_a0_bytes = [ + 214, 62, 78, 148, 85, 33, 95, 146, 49, 88, 94, 54, 52, 208, 3, 136, 177, 46, 77, 253, 17, + 128, 131, 235, 82, 176, 80, 134, 59, 52, 163, 238, 32, 181, 131, 56, 17, 55, 66, 102, 145, + 191, 18, 175, 151, 1, 212, 23, + ]; + let x0_b1_a1_bytes = [ + 41, 167, 64, 159, 223, 51, 189, 43, 186, 251, 202, 72, 55, 36, 85, 193, 232, 226, 132, 96, + 154, 82, 119, 118, 133, 141, 95, 19, 205, 2, 134, 48, 181, 178, 133, 101, 88, 189, 43, 189, + 238, 133, 161, 60, 82, 210, 193, 25, + ]; + let x0_b2_a0_bytes = [ + 69, 152, 0, 136, 208, 43, 221, 129, 150, 113, 46, 202, 33, 249, 218, 176, 47, 123, 129, + 203, 88, 135, 65, 235, 24, 13, 135, 20, 230, 253, 169, 246, 55, 229, 221, 139, 91, 205, + 100, 77, 117, 152, 144, 112, 64, 105, 19, 21, + ]; + let x0_b2_a1_bytes = [ + 91, 154, 129, 212, 234, 209, 169, 160, 142, 49, 247, 206, 85, 255, 156, 123, 218, 140, 13, + 35, 79, 130, 173, 36, 205, 226, 38, 38, 253, 40, 49, 195, 138, 58, 160, 15, 228, 18, 97, + 149, 42, 224, 34, 135, 225, 42, 216, 15, + ]; + let x1_b0_a0_bytes = [ + 168, 144, 97, 71, 250, 233, 57, 194, 117, 19, 227, 238, 182, 173, 56, 31, 77, 42, 237, 203, + 81, 157, 105, 108, 51, 186, 234, 114, 230, 161, 213, 26, 154, 32, 89, 75, 11, 160, 27, 146, + 90, 226, 1, 45, 226, 94, 235, 23, + ]; + let x1_b0_a1_bytes = [ + 1, 241, 173, 149, 51, 212, 21, 36, 198, 72, 155, 117, 227, 230, 43, 12, 239, 110, 117, 76, + 151, 134, 20, 75, 136, 2, 197, 149, 210, 100, 232, 213, 66, 182, 114, 49, 237, 192, 134, + 188, 192, 157, 229, 5, 205, 26, 72, 7, + ]; + let x1_b1_a0_bytes = [ + 5, 131, 227, 108, 57, 93, 117, 63, 62, 3, 235, 177, 236, 31, 181, 189, 212, 89, 138, 143, + 76, 255, 243, 255, 18, 170, 199, 28, 241, 228, 251, 200, 4, 18, 141, 186, 170, 58, 136, + 235, 114, 55, 39, 38, 1, 16, 35, 1, + ]; + let x1_b1_a1_bytes = [ + 125, 64, 186, 137, 111, 34, 155, 104, 156, 45, 242, 173, 235, 118, 208, 41, 134, 62, 54, + 225, 33, 126, 182, 34, 254, 7, 92, 226, 214, 219, 134, 153, 38, 192, 67, 164, 136, 69, 162, + 207, 122, 195, 73, 43, 24, 120, 96, 13, + ]; + let x1_b2_a0_bytes = [ + 145, 182, 101, 27, 67, 208, 10, 14, 239, 224, 162, 122, 20, 230, 25, 90, 124, 227, 52, 206, + 100, 13, 49, 213, 210, 224, 63, 236, 90, 227, 56, 138, 35, 218, 165, 113, 114, 120, 139, + 135, 191, 21, 32, 64, 126, 59, 230, 2, + ]; + let x1_b2_a1_bytes = [ + 93, 163, 83, 188, 82, 139, 106, 196, 217, 193, 42, 85, 147, 98, 114, 220, 131, 93, 17, 61, + 214, 81, 211, 13, 80, 49, 149, 41, 98, 183, 38, 215, 179, 227, 251, 194, 75, 197, 11, 128, + 111, 231, 95, 246, 179, 151, 8, 10, + ]; + let x2_b0_a0_bytes = [ + 40, 185, 253, 35, 80, 238, 66, 9, 77, 231, 236, 20, 10, 195, 196, 57, 180, 116, 174, 179, + 211, 195, 190, 6, 205, 104, 67, 158, 0, 111, 104, 82, 221, 209, 233, 184, 70, 179, 246, 6, + 118, 88, 247, 185, 12, 131, 22, 12, + ]; + let x2_b0_a1_bytes = [ + 64, 116, 129, 235, 7, 253, 237, 210, 84, 223, 176, 106, 155, 75, 43, 212, 192, 132, 70, + 164, 217, 70, 118, 250, 116, 119, 36, 132, 48, 121, 130, 249, 230, 179, 171, 251, 186, 1, + 135, 132, 116, 175, 42, 34, 134, 188, 142, 20, + ]; + let x2_b1_a0_bytes = [ + 219, 193, 49, 1, 143, 126, 212, 209, 111, 91, 73, 232, 32, 240, 184, 69, 134, 136, 215, + 140, 94, 127, 119, 235, 101, 90, 24, 163, 44, 25, 159, 183, 37, 199, 16, 243, 187, 113, + 202, 81, 4, 247, 57, 213, 152, 17, 247, 24, + ]; + let x2_b1_a1_bytes = [ + 251, 60, 251, 40, 79, 86, 89, 218, 86, 41, 105, 69, 36, 155, 121, 204, 74, 43, 10, 75, 27, + 254, 252, 49, 196, 130, 54, 2, 31, 147, 149, 101, 4, 198, 125, 198, 42, 91, 178, 65, 207, + 98, 107, 46, 128, 56, 33, 13, + ]; + let x2_b2_a0_bytes = [ + 214, 78, 102, 163, 19, 252, 231, 143, 133, 82, 209, 68, 54, 223, 244, 10, 172, 94, 182, + 153, 189, 148, 114, 192, 235, 237, 198, 0, 65, 225, 226, 128, 91, 191, 131, 253, 205, 69, + 240, 212, 52, 174, 176, 176, 190, 164, 249, 23, + ]; + let x2_b2_a1_bytes = [ + 184, 61, 213, 144, 61, 93, 20, 101, 104, 243, 33, 36, 233, 97, 15, 88, 94, 234, 30, 96, 37, + 212, 128, 50, 29, 20, 188, 79, 95, 224, 87, 154, 62, 30, 156, 210, 47, 216, 108, 21, 154, + 199, 130, 125, 149, 194, 224, 25, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6AddCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6SubCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.sub(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_sub() { + compile_generic(&E6SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6SubCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 117, 67, 202, 118, 173, 110, 225, 14, 221, 151, 124, 122, 61, 149, 241, 18, 203, 205, 177, + 75, 70, 107, 95, 134, 44, 31, 134, 223, 223, 119, 166, 241, 140, 160, 77, 31, 209, 113, + 203, 150, 180, 66, 197, 237, 193, 121, 208, 0, + ]; + let x0_b0_a1_bytes = [ + 110, 149, 76, 85, 199, 140, 8, 167, 128, 140, 218, 6, 61, 135, 234, 132, 175, 254, 240, + 100, 114, 91, 133, 61, 241, 86, 124, 142, 78, 33, 16, 246, 74, 52, 117, 19, 196, 33, 175, + 78, 43, 217, 62, 140, 22, 40, 10, 4, + ]; + let x0_b1_a0_bytes = [ + 167, 132, 89, 16, 118, 145, 244, 205, 24, 211, 7, 12, 137, 89, 178, 181, 153, 189, 41, 159, + 221, 184, 32, 188, 221, 84, 166, 48, 42, 197, 3, 73, 145, 51, 1, 61, 75, 2, 126, 160, 130, + 90, 183, 50, 169, 255, 244, 24, + ]; + let x0_b1_a1_bytes = [ + 39, 39, 199, 18, 85, 70, 114, 161, 125, 13, 253, 192, 41, 206, 162, 138, 196, 35, 243, 215, + 93, 0, 63, 90, 210, 114, 174, 223, 9, 211, 206, 184, 5, 176, 16, 169, 163, 132, 213, 168, + 237, 89, 183, 208, 107, 228, 88, 12, + ]; + let x0_b2_a0_bytes = [ + 3, 237, 190, 146, 70, 78, 10, 88, 226, 63, 22, 92, 151, 39, 13, 220, 63, 81, 10, 156, 43, + 201, 81, 202, 56, 56, 158, 192, 4, 42, 104, 209, 22, 195, 72, 183, 191, 39, 42, 147, 4, + 148, 232, 13, 145, 17, 54, 5, + ]; + let x0_b2_a1_bytes = [ + 225, 201, 6, 16, 49, 244, 117, 191, 166, 244, 42, 86, 39, 183, 237, 161, 17, 110, 212, 223, + 85, 115, 32, 210, 129, 151, 83, 12, 9, 192, 33, 159, 224, 159, 53, 119, 240, 95, 45, 169, + 13, 178, 183, 132, 43, 223, 8, 15, + ]; + let x1_b0_a0_bytes = [ + 143, 28, 221, 28, 84, 196, 131, 92, 212, 0, 200, 243, 196, 73, 255, 59, 7, 52, 5, 52, 7, + 221, 107, 182, 61, 65, 255, 95, 11, 146, 158, 222, 57, 139, 232, 252, 181, 149, 181, 61, + 71, 64, 160, 147, 89, 79, 87, 3, + ]; + let x1_b0_a1_bytes = [ + 192, 234, 124, 255, 103, 182, 125, 220, 156, 88, 109, 214, 103, 250, 217, 101, 68, 101, 36, + 254, 247, 79, 161, 60, 204, 171, 112, 23, 167, 16, 103, 254, 102, 55, 211, 111, 96, 222, + 146, 96, 106, 97, 77, 204, 16, 225, 246, 18, + ]; + let x1_b1_a0_bytes = [ + 28, 10, 69, 145, 40, 112, 221, 180, 163, 241, 233, 95, 178, 55, 10, 21, 76, 41, 31, 233, 7, + 242, 254, 187, 102, 68, 8, 118, 125, 34, 138, 22, 160, 179, 58, 176, 187, 214, 3, 245, 114, + 136, 0, 180, 234, 133, 85, 14, + ]; + let x1_b1_a1_bytes = [ + 119, 92, 66, 14, 39, 115, 82, 109, 0, 155, 226, 84, 212, 158, 188, 52, 234, 232, 165, 207, + 90, 156, 117, 52, 127, 224, 21, 27, 202, 135, 43, 189, 157, 13, 137, 2, 248, 24, 5, 250, + 183, 70, 125, 194, 206, 183, 148, 19, + ]; + let x1_b2_a0_bytes = [ + 172, 52, 244, 121, 0, 171, 124, 120, 72, 244, 219, 141, 30, 203, 101, 43, 76, 75, 35, 11, + 38, 13, 228, 90, 204, 27, 44, 108, 122, 94, 152, 135, 222, 164, 120, 85, 235, 64, 4, 44, + 242, 82, 68, 209, 105, 31, 133, 16, + ]; + let x1_b2_a1_bytes = [ + 3, 242, 58, 112, 155, 25, 152, 168, 242, 27, 59, 163, 47, 158, 43, 229, 19, 111, 181, 191, + 83, 236, 195, 148, 203, 169, 66, 113, 114, 122, 78, 15, 220, 32, 103, 124, 248, 65, 17, + 148, 68, 127, 27, 54, 166, 19, 190, 0, + ]; + let x2_b0_a0_bytes = [ + 145, 209, 236, 89, 89, 170, 92, 108, 8, 151, 8, 56, 119, 75, 158, 245, 231, 143, 93, 14, + 224, 96, 36, 55, 174, 240, 11, 115, 89, 49, 127, 119, 42, 194, 176, 101, 209, 131, 49, 164, + 7, 233, 164, 147, 82, 60, 122, 23, + ]; + let x2_b0_a1_bytes = [ + 89, 85, 207, 85, 95, 214, 137, 132, 227, 51, 193, 225, 211, 140, 188, 61, 143, 143, 125, + 93, 27, 222, 20, 104, 228, 189, 144, 106, 44, 92, 32, 92, 187, 169, 237, 230, 25, 235, 55, + 57, 91, 94, 113, 249, 239, 88, 20, 11, + ]; + let x2_b1_a0_bytes = [ + 139, 122, 20, 127, 77, 33, 23, 25, 117, 225, 29, 172, 214, 33, 168, 160, 77, 148, 10, 182, + 213, 198, 33, 0, 119, 16, 158, 186, 172, 162, 121, 50, 241, 127, 198, 140, 143, 43, 122, + 171, 15, 210, 182, 126, 190, 121, 159, 10, + ]; + let x2_b1_a1_bytes = [ + 91, 117, 132, 4, 46, 211, 30, 238, 124, 114, 110, 29, 84, 47, 146, 116, 254, 48, 254, 254, + 163, 54, 250, 140, 18, 165, 29, 184, 196, 150, 26, 96, 63, 79, 211, 233, 97, 19, 236, 249, + 207, 249, 185, 71, 135, 62, 197, 18, + ]; + let x2_b2_a0_bytes = [ + 2, 99, 202, 24, 70, 163, 140, 153, 153, 75, 142, 127, 119, 92, 83, 207, 23, 252, 151, 135, + 166, 142, 158, 214, 43, 47, 247, 71, 15, 23, 71, 174, 15, 203, 27, 165, 138, 142, 65, 178, + 172, 39, 36, 118, 17, 4, 178, 14, + ]; + let x2_b2_a1_bytes = [ + 222, 215, 203, 159, 149, 218, 221, 22, 180, 216, 239, 178, 247, 24, 194, 188, 253, 254, 30, + 32, 2, 135, 92, 61, 182, 237, 16, 155, 150, 69, 211, 143, 4, 127, 206, 250, 247, 29, 28, + 21, 201, 50, 156, 78, 133, 203, 74, 14, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6SubCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E6MulCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.mul(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul() { + compile_generic(&E6MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 46, 171, 188, 186, 190, 115, 108, 16, 106, 47, 30, 48, 92, 33, 24, 187, 243, 219, 27, 71, + 225, 210, 31, 244, 228, 11, 110, 205, 138, 94, 101, 51, 32, 146, 68, 158, 91, 248, 87, 49, + 113, 45, 18, 9, 66, 223, 1, 9, + ]; + let x0_b0_a1_bytes = [ + 200, 59, 2, 153, 8, 53, 214, 186, 105, 82, 243, 109, 164, 109, 113, 140, 250, 42, 7, 118, + 205, 121, 7, 142, 25, 196, 1, 120, 181, 155, 93, 59, 47, 9, 39, 56, 222, 243, 229, 81, 42, + 190, 234, 135, 29, 21, 58, 10, + ]; + let x0_b1_a0_bytes = [ + 44, 28, 34, 122, 59, 250, 97, 234, 89, 159, 141, 225, 198, 102, 238, 93, 2, 213, 43, 132, + 40, 208, 140, 196, 58, 226, 107, 20, 163, 33, 14, 18, 176, 3, 23, 16, 30, 125, 126, 32, 22, + 190, 71, 210, 30, 191, 219, 11, + ]; + let x0_b1_a1_bytes = [ + 117, 245, 238, 225, 186, 36, 41, 224, 112, 118, 52, 177, 6, 63, 94, 95, 195, 156, 135, 55, + 66, 238, 102, 19, 236, 170, 247, 0, 192, 35, 113, 135, 126, 252, 180, 6, 19, 225, 9, 182, + 205, 4, 15, 215, 223, 141, 27, 12, + ]; + let x0_b2_a0_bytes = [ + 39, 225, 139, 50, 21, 53, 177, 230, 184, 63, 137, 162, 135, 228, 11, 252, 62, 38, 15, 226, + 82, 118, 68, 100, 144, 193, 13, 144, 106, 160, 183, 126, 103, 164, 151, 4, 93, 223, 90, + 137, 128, 105, 212, 176, 142, 231, 9, 13, + ]; + let x0_b2_a1_bytes = [ + 13, 33, 87, 166, 233, 45, 135, 152, 194, 168, 223, 42, 131, 60, 4, 47, 58, 198, 193, 106, + 193, 188, 61, 167, 198, 143, 154, 46, 53, 12, 174, 127, 82, 235, 72, 155, 54, 216, 81, 166, + 76, 250, 194, 201, 20, 170, 145, 14, + ]; + let x1_b0_a0_bytes = [ + 2, 211, 218, 184, 13, 175, 37, 119, 109, 40, 212, 219, 183, 74, 233, 163, 185, 243, 126, + 237, 106, 186, 211, 233, 160, 102, 0, 230, 100, 165, 248, 28, 96, 119, 174, 107, 209, 142, + 190, 193, 152, 62, 155, 175, 169, 70, 198, 1, + ]; + let x1_b0_a1_bytes = [ + 2, 133, 167, 173, 76, 108, 164, 230, 130, 110, 187, 191, 213, 215, 105, 214, 206, 183, 176, + 90, 84, 70, 109, 18, 236, 29, 96, 101, 149, 41, 37, 218, 71, 92, 40, 234, 134, 231, 239, + 125, 255, 90, 112, 176, 182, 248, 118, 3, + ]; + let x1_b1_a0_bytes = [ + 84, 102, 133, 136, 37, 82, 182, 154, 143, 152, 228, 7, 202, 193, 77, 174, 99, 19, 163, 168, + 144, 32, 47, 97, 46, 107, 52, 174, 168, 67, 202, 93, 144, 247, 196, 217, 179, 40, 147, 112, + 208, 95, 228, 191, 236, 175, 23, 21, + ]; + let x1_b1_a1_bytes = [ + 250, 209, 134, 38, 35, 182, 176, 144, 176, 100, 39, 18, 144, 67, 229, 122, 63, 26, 6, 185, + 14, 76, 77, 69, 198, 73, 252, 148, 179, 201, 15, 229, 74, 147, 206, 37, 103, 84, 160, 82, + 223, 173, 206, 135, 34, 221, 149, 19, + ]; + let x1_b2_a0_bytes = [ + 78, 219, 161, 76, 22, 59, 94, 124, 156, 131, 175, 147, 51, 145, 148, 54, 54, 193, 166, 92, + 244, 72, 183, 189, 189, 119, 33, 102, 90, 90, 228, 193, 246, 103, 108, 63, 181, 50, 240, + 142, 75, 148, 11, 253, 219, 175, 4, 18, + ]; + let x1_b2_a1_bytes = [ + 157, 255, 244, 149, 96, 149, 68, 19, 16, 227, 89, 166, 192, 157, 80, 183, 121, 211, 186, 8, + 244, 156, 202, 65, 14, 189, 252, 38, 110, 38, 172, 34, 136, 186, 155, 102, 39, 200, 132, + 159, 155, 58, 186, 36, 41, 164, 111, 20, + ]; + let x2_b0_a0_bytes = [ + 139, 57, 43, 3, 203, 41, 159, 16, 165, 223, 135, 253, 137, 144, 225, 68, 65, 203, 47, 32, + 3, 82, 64, 122, 20, 104, 160, 155, 106, 139, 224, 96, 40, 95, 114, 1, 213, 182, 187, 111, + 179, 56, 224, 4, 45, 79, 115, 19, + ]; + let x2_b0_a1_bytes = [ + 182, 46, 28, 46, 128, 147, 103, 169, 72, 64, 229, 0, 37, 163, 104, 210, 193, 180, 172, 228, + 228, 129, 16, 194, 11, 41, 55, 53, 204, 163, 74, 69, 245, 7, 24, 42, 79, 15, 171, 228, 122, + 254, 81, 177, 236, 102, 202, 9, + ]; + let x2_b1_a0_bytes = [ + 198, 127, 46, 145, 88, 18, 205, 163, 244, 216, 212, 57, 7, 225, 227, 66, 178, 27, 48, 206, + 191, 120, 8, 212, 167, 146, 38, 34, 123, 43, 223, 50, 131, 109, 49, 118, 100, 5, 30, 194, + 25, 89, 176, 3, 231, 181, 38, 18, + ]; + let x2_b1_a1_bytes = [ + 194, 218, 15, 76, 86, 206, 59, 118, 75, 9, 124, 137, 170, 6, 84, 184, 125, 247, 228, 139, + 152, 171, 125, 242, 137, 199, 170, 11, 116, 83, 40, 184, 189, 14, 93, 195, 111, 138, 213, + 242, 212, 90, 128, 60, 50, 132, 69, 0, + ]; + let x2_b2_a0_bytes = [ + 239, 2, 119, 9, 143, 45, 156, 90, 96, 201, 15, 104, 44, 158, 202, 13, 109, 55, 21, 111, 75, + 182, 173, 240, 31, 203, 253, 85, 116, 120, 118, 81, 170, 84, 219, 136, 90, 225, 140, 106, + 110, 222, 193, 62, 128, 47, 233, 3, + ]; + let x2_b2_a1_bytes = [ + 163, 224, 214, 44, 217, 30, 86, 63, 64, 74, 49, 222, 85, 74, 144, 121, 178, 207, 115, 64, + 58, 69, 243, 3, 42, 210, 225, 158, 53, 32, 60, 206, 224, 25, 208, 203, 198, 36, 195, 177, + 49, 37, 9, 229, 194, 16, 66, 13, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6SquareCircuit { + x: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = Ext6::square(&mut ext6, builder, &x_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_square() { + compile_generic(&E6SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6SquareCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 149, 252, 160, 161, 66, 108, 73, 228, 243, 168, 88, 37, 39, 191, 205, 98, 241, 61, 156, 45, + 52, 99, 67, 183, 178, 209, 195, 34, 3, 60, 173, 58, 42, 202, 210, 5, 243, 177, 190, 5, 100, + 201, 100, 209, 177, 231, 187, 21, + ]; + let x0_b0_a1_bytes = [ + 71, 251, 181, 71, 28, 134, 218, 38, 21, 32, 1, 21, 12, 198, 125, 39, 126, 54, 18, 10, 211, + 211, 104, 12, 203, 201, 22, 109, 65, 3, 1, 27, 81, 91, 222, 53, 40, 245, 103, 137, 79, 164, + 255, 137, 145, 160, 203, 14, + ]; + let x0_b1_a0_bytes = [ + 205, 77, 53, 46, 150, 38, 185, 19, 233, 44, 84, 29, 158, 181, 240, 47, 163, 3, 60, 164, + 129, 252, 205, 122, 22, 84, 219, 0, 146, 112, 155, 9, 115, 133, 84, 26, 18, 164, 163, 46, + 177, 9, 213, 50, 103, 38, 251, 19, + ]; + let x0_b1_a1_bytes = [ + 223, 114, 215, 138, 45, 155, 174, 77, 6, 236, 176, 6, 65, 105, 33, 159, 192, 203, 32, 175, + 68, 156, 172, 222, 85, 103, 32, 36, 253, 197, 35, 30, 173, 48, 57, 212, 101, 214, 118, 190, + 92, 26, 177, 126, 37, 200, 151, 0, + ]; + let x0_b2_a0_bytes = [ + 111, 205, 175, 51, 14, 14, 198, 159, 176, 90, 194, 167, 0, 56, 230, 245, 50, 250, 31, 186, + 192, 108, 141, 75, 129, 86, 203, 69, 3, 152, 246, 84, 135, 11, 208, 177, 161, 143, 194, 0, + 99, 6, 201, 91, 5, 202, 196, 25, + ]; + let x0_b2_a1_bytes = [ + 99, 11, 232, 254, 225, 220, 249, 134, 36, 14, 216, 116, 146, 232, 227, 0, 25, 38, 227, 90, + 221, 113, 88, 108, 85, 40, 251, 88, 105, 103, 27, 208, 30, 113, 129, 203, 249, 108, 144, + 154, 211, 251, 107, 12, 168, 105, 81, 1, + ]; + let x2_b0_a0_bytes = [ + 21, 61, 58, 202, 150, 61, 40, 78, 118, 188, 60, 67, 131, 26, 108, 110, 94, 101, 43, 230, + 149, 87, 4, 207, 232, 27, 6, 220, 59, 150, 3, 211, 185, 62, 139, 123, 205, 7, 160, 187, + 143, 73, 151, 82, 50, 160, 193, 21, + ]; + let x2_b0_a1_bytes = [ + 84, 111, 79, 158, 196, 154, 235, 30, 225, 34, 147, 112, 32, 10, 3, 32, 32, 18, 230, 244, + 84, 230, 163, 116, 200, 228, 152, 247, 75, 60, 129, 62, 23, 205, 10, 243, 139, 55, 149, + 133, 138, 253, 102, 67, 135, 148, 215, 12, + ]; + let x2_b1_a0_bytes = [ + 252, 95, 170, 53, 240, 79, 250, 214, 195, 45, 219, 214, 5, 204, 25, 135, 59, 205, 74, 233, + 211, 96, 45, 236, 68, 55, 107, 182, 36, 114, 211, 245, 43, 119, 254, 19, 178, 186, 73, 240, + 160, 164, 21, 145, 101, 105, 34, 14, + ]; + let x2_b1_a1_bytes = [ + 36, 26, 27, 52, 88, 138, 91, 54, 24, 252, 143, 17, 39, 84, 137, 8, 191, 39, 110, 10, 128, + 92, 128, 150, 191, 216, 22, 202, 75, 194, 99, 92, 20, 247, 159, 212, 122, 217, 46, 186, 86, + 242, 95, 187, 128, 14, 38, 5, + ]; + let x2_b2_a0_bytes = [ + 193, 78, 94, 37, 120, 49, 230, 20, 47, 17, 14, 25, 228, 74, 163, 207, 94, 107, 42, 232, + 230, 107, 131, 61, 250, 195, 232, 77, 250, 90, 114, 234, 173, 250, 168, 6, 172, 100, 78, + 35, 121, 210, 81, 97, 89, 82, 156, 17, + ]; + let x2_b2_a1_bytes = [ + 22, 126, 225, 109, 245, 84, 53, 66, 154, 187, 48, 16, 56, 105, 180, 247, 79, 94, 107, 74, + 174, 39, 224, 37, 9, 10, 74, 204, 85, 33, 2, 165, 244, 66, 179, 232, 52, 28, 97, 71, 5, + 169, 96, 142, 213, 59, 47, 19, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6DivCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.div(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_div() { + // let compile_result = + // compile_generic(&E6DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6DivCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 107, 46, 111, 157, 84, 135, 89, 107, 29, 18, 126, 99, 75, 231, 135, 136, 247, 175, 57, 99, + 90, 48, 149, 234, 25, 93, 172, 7, 58, 116, 96, 138, 58, 167, 206, 46, 194, 47, 132, 61, 81, + 255, 143, 139, 9, 178, 179, 24, + ]; + let x0_b0_a1_bytes = [ + 65, 150, 235, 198, 199, 204, 132, 179, 17, 239, 168, 83, 18, 235, 124, 242, 186, 37, 23, + 63, 212, 62, 143, 188, 225, 59, 144, 230, 131, 184, 85, 242, 107, 221, 207, 52, 189, 231, + 244, 131, 25, 123, 52, 56, 61, 9, 22, 20, + ]; + let x0_b1_a0_bytes = [ + 173, 39, 135, 175, 251, 127, 251, 89, 158, 139, 94, 66, 180, 143, 155, 50, 213, 196, 158, + 102, 168, 240, 200, 30, 74, 10, 136, 214, 182, 205, 96, 211, 42, 67, 117, 205, 187, 245, + 70, 16, 253, 106, 190, 159, 65, 142, 118, 12, + ]; + let x0_b1_a1_bytes = [ + 231, 106, 130, 80, 207, 77, 88, 201, 127, 90, 167, 140, 61, 4, 133, 64, 239, 153, 233, 31, + 153, 238, 25, 23, 203, 39, 59, 37, 7, 191, 226, 200, 133, 35, 91, 114, 57, 124, 77, 70, + 252, 40, 241, 60, 103, 188, 249, 23, + ]; + let x0_b2_a0_bytes = [ + 119, 50, 63, 185, 207, 181, 225, 181, 10, 24, 209, 197, 165, 151, 189, 133, 107, 135, 22, + 230, 46, 166, 178, 27, 159, 132, 48, 130, 126, 52, 108, 36, 236, 227, 27, 98, 88, 15, 205, + 18, 147, 23, 65, 177, 186, 202, 219, 19, + ]; + let x0_b2_a1_bytes = [ + 165, 58, 17, 37, 247, 187, 48, 54, 42, 252, 33, 95, 119, 174, 86, 195, 0, 104, 57, 143, + 164, 118, 207, 61, 240, 19, 145, 50, 187, 85, 46, 215, 93, 133, 181, 13, 96, 65, 146, 185, + 132, 116, 84, 145, 253, 103, 193, 19, + ]; + let x1_b0_a0_bytes = [ + 16, 79, 32, 49, 174, 6, 172, 207, 122, 139, 231, 68, 149, 199, 95, 98, 12, 84, 238, 96, + 101, 210, 104, 62, 64, 216, 27, 120, 43, 210, 103, 245, 8, 199, 91, 75, 67, 163, 246, 235, + 19, 66, 153, 185, 41, 186, 103, 5, + ]; + let x1_b0_a1_bytes = [ + 57, 238, 57, 195, 235, 52, 131, 101, 220, 163, 24, 39, 229, 83, 27, 121, 219, 17, 39, 82, + 86, 239, 237, 251, 127, 220, 229, 92, 111, 31, 58, 175, 86, 76, 37, 169, 23, 148, 115, 146, + 124, 241, 174, 228, 149, 9, 90, 6, + ]; + let x1_b1_a0_bytes = [ + 247, 148, 68, 210, 199, 239, 86, 29, 204, 205, 220, 164, 22, 11, 24, 35, 228, 244, 237, + 116, 25, 70, 189, 251, 247, 70, 117, 156, 224, 249, 17, 138, 63, 50, 78, 4, 155, 91, 30, + 26, 123, 159, 172, 23, 130, 144, 43, 25, + ]; + let x1_b1_a1_bytes = [ + 60, 103, 177, 115, 150, 175, 97, 91, 229, 107, 241, 226, 110, 3, 139, 96, 108, 37, 224, + 144, 45, 117, 18, 230, 93, 140, 255, 15, 131, 111, 155, 73, 142, 169, 96, 196, 69, 110, + 227, 144, 70, 184, 233, 207, 145, 70, 3, 0, + ]; + let x1_b2_a0_bytes = [ + 199, 33, 152, 245, 103, 119, 131, 68, 162, 115, 65, 191, 82, 228, 118, 227, 249, 183, 102, + 194, 217, 231, 28, 41, 83, 99, 36, 244, 250, 58, 231, 247, 65, 63, 127, 246, 254, 218, 128, + 63, 150, 53, 205, 127, 25, 160, 45, 21, + ]; + let x1_b2_a1_bytes = [ + 149, 118, 225, 27, 180, 204, 98, 78, 29, 25, 184, 252, 36, 166, 66, 106, 123, 142, 80, 56, + 225, 137, 128, 130, 194, 102, 142, 115, 42, 12, 187, 161, 9, 23, 9, 34, 199, 12, 73, 213, + 22, 80, 114, 193, 138, 69, 67, 16, + ]; + let x2_b0_a0_bytes = [ + 90, 197, 146, 236, 129, 61, 116, 59, 100, 18, 45, 130, 188, 202, 114, 151, 175, 48, 14, + 125, 137, 143, 100, 130, 199, 246, 11, 98, 206, 173, 27, 90, 238, 217, 195, 190, 244, 184, + 44, 110, 36, 35, 90, 250, 84, 187, 120, 11, + ]; + let x2_b0_a1_bytes = [ + 156, 140, 120, 55, 221, 129, 220, 124, 199, 65, 79, 230, 109, 209, 226, 177, 66, 182, 240, + 70, 63, 51, 79, 248, 163, 108, 109, 49, 94, 187, 20, 174, 22, 226, 131, 36, 33, 33, 148, + 76, 96, 169, 72, 146, 78, 134, 169, 22, + ]; + let x2_b1_a0_bytes = [ + 164, 204, 252, 143, 75, 2, 19, 248, 173, 72, 189, 106, 203, 49, 221, 71, 109, 218, 238, 90, + 49, 209, 82, 251, 197, 96, 219, 145, 69, 188, 129, 219, 65, 76, 185, 220, 97, 253, 231, + 125, 226, 252, 178, 159, 83, 25, 55, 13, + ]; + let x2_b1_a1_bytes = [ + 191, 109, 242, 246, 21, 112, 126, 212, 129, 232, 137, 91, 89, 38, 9, 142, 25, 97, 38, 146, + 30, 113, 12, 214, 44, 194, 123, 45, 28, 142, 124, 137, 153, 160, 18, 38, 250, 208, 129, 46, + 181, 60, 20, 233, 105, 102, 124, 12, + ]; + let x2_b2_a0_bytes = [ + 222, 43, 171, 59, 32, 102, 33, 247, 125, 121, 241, 64, 19, 99, 21, 169, 182, 203, 33, 160, + 245, 2, 234, 186, 2, 46, 154, 173, 209, 58, 169, 112, 207, 46, 35, 152, 250, 162, 239, 99, + 154, 73, 56, 209, 26, 4, 113, 21, + ]; + let x2_b2_a1_bytes = [ + 228, 214, 111, 241, 243, 60, 177, 143, 184, 255, 55, 230, 82, 186, 163, 92, 237, 57, 148, + 219, 0, 129, 130, 243, 246, 252, 253, 72, 173, 70, 236, 178, 95, 186, 219, 127, 127, 214, + 36, 192, 161, 233, 161, 237, 197, 138, 146, 16, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6MulByNonResidueCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulByNonResidueCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_non_residue(builder, &a_e6); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_non_residue() { + compile_generic( + &E6MulByNonResidueCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); // Updated hint registration + + let mut assignment = E6MulByNonResidueCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 64, 88, 27, 110, 238, 39, 175, 216, 0, 29, 131, 126, 214, 115, 176, 254, 76, 55, 0, 215, + 59, 70, 40, 219, 237, 215, 146, 219, 178, 177, 230, 83, 93, 215, 207, 32, 189, 190, 197, + 133, 30, 113, 224, 95, 33, 111, 88, 0, + ]; + let x0_b0_a1_bytes = [ + 122, 78, 181, 224, 62, 88, 174, 158, 82, 231, 130, 108, 51, 204, 90, 167, 55, 38, 234, 69, + 242, 182, 217, 230, 63, 135, 52, 193, 222, 71, 109, 97, 201, 228, 118, 32, 66, 97, 177, 39, + 136, 245, 14, 185, 224, 252, 41, 16, + ]; + let x0_b1_a0_bytes = [ + 19, 197, 36, 21, 31, 161, 152, 225, 90, 247, 154, 217, 54, 210, 113, 218, 37, 48, 18, 232, + 196, 128, 209, 136, 220, 3, 88, 71, 54, 180, 158, 44, 100, 135, 14, 96, 125, 46, 82, 140, + 201, 53, 79, 149, 38, 100, 5, 6, + ]; + let x0_b1_a1_bytes = [ + 59, 16, 99, 177, 130, 33, 110, 86, 138, 187, 1, 227, 142, 131, 36, 234, 164, 215, 71, 206, + 79, 145, 201, 34, 138, 244, 1, 46, 141, 35, 110, 92, 237, 207, 216, 108, 22, 224, 70, 148, + 146, 55, 87, 189, 20, 82, 12, 17, + ]; + let x0_b2_a0_bytes = [ + 74, 190, 238, 44, 234, 56, 156, 176, 254, 232, 115, 121, 131, 101, 133, 143, 203, 79, 126, + 36, 45, 89, 244, 171, 139, 36, 88, 144, 76, 160, 27, 232, 239, 54, 71, 229, 147, 4, 218, + 192, 199, 157, 95, 79, 10, 1, 249, 11, + ]; + let x0_b2_a1_bytes = [ + 180, 248, 244, 93, 213, 144, 28, 114, 150, 60, 209, 143, 249, 0, 232, 139, 255, 201, 20, + 252, 109, 69, 225, 215, 17, 242, 137, 229, 0, 49, 158, 32, 234, 225, 207, 223, 55, 93, 15, + 83, 134, 142, 58, 203, 248, 80, 179, 11, + ]; + let x2_b0_a0_bytes = [ + 150, 197, 249, 206, 20, 168, 127, 62, 104, 172, 162, 233, 137, 100, 157, 3, 204, 133, 105, + 40, 191, 19, 19, 212, 121, 50, 206, 170, 75, 111, 125, 199, 5, 85, 119, 5, 92, 167, 202, + 109, 65, 15, 37, 132, 17, 176, 69, 0, + ]; + let x2_b0_a1_bytes = [ + 254, 182, 227, 138, 191, 201, 184, 34, 149, 37, 69, 9, 125, 102, 109, 27, 203, 25, 147, 32, + 155, 158, 213, 131, 157, 22, 226, 117, 77, 209, 185, 8, 218, 24, 23, 197, 203, 97, 233, 19, + 78, 44, 154, 26, 3, 82, 172, 23, + ]; + let x2_b1_a0_bytes = [ + 64, 88, 27, 110, 238, 39, 175, 216, 0, 29, 131, 126, 214, 115, 176, 254, 76, 55, 0, 215, + 59, 70, 40, 219, 237, 215, 146, 219, 178, 177, 230, 83, 93, 215, 207, 32, 189, 190, 197, + 133, 30, 113, 224, 95, 33, 111, 88, 0, + ]; + let x2_b1_a1_bytes = [ + 122, 78, 181, 224, 62, 88, 174, 158, 82, 231, 130, 108, 51, 204, 90, 167, 55, 38, 234, 69, + 242, 182, 217, 230, 63, 135, 52, 193, 222, 71, 109, 97, 201, 228, 118, 32, 66, 97, 177, 39, + 136, 245, 14, 185, 224, 252, 41, 16, + ]; + let x2_b2_a0_bytes = [ + 19, 197, 36, 21, 31, 161, 152, 225, 90, 247, 154, 217, 54, 210, 113, 218, 37, 48, 18, 232, + 196, 128, 209, 136, 220, 3, 88, 71, 54, 180, 158, 44, 100, 135, 14, 96, 125, 46, 82, 140, + 201, 53, 79, 149, 38, 100, 5, 6, + ]; + let x2_b2_a1_bytes = [ + 59, 16, 99, 177, 130, 33, 110, 86, 138, 187, 1, 227, 142, 131, 36, 234, 164, 215, 71, 206, + 79, 145, 201, 34, 138, 244, 1, 46, 141, 35, 110, 92, 237, 207, 216, 108, 22, 224, 70, 148, + 146, 55, 87, 189, 20, 82, 12, 17, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval( + &E6MulByNonResidueCircuit::default(), + &assignment, + hint_registry, + ); +} +declare_circuit!(E6MulByE2Circuit { + a: [[[Variable; 48]; 2]; 3], + b: [[Variable; 48]; 2], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulByE2Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let b_e2 = GE2 { + a0: new_internal_element(self.b[0].to_vec(), 0), + a1: new_internal_element(self.b[1].to_vec(), 0), + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_e2(builder, &a_e6, &b_e2); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_e2() { + compile_generic(&E6MulByE2Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulByE2Circuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + b: [[M31::from(0); 48]; 2], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 16, 57, 17, 157, 215, 105, 216, 201, 10, 247, 112, 166, 181, 199, 152, 28, 187, 8, 152, + 145, 14, 226, 75, 178, 88, 143, 56, 117, 1, 55, 178, 123, 152, 85, 192, 63, 120, 146, 235, + 227, 59, 102, 139, 161, 232, 201, 13, 15, + ]; + let x0_b0_a1_bytes = [ + 147, 25, 67, 213, 252, 165, 176, 151, 237, 58, 214, 37, 92, 194, 214, 83, 112, 89, 63, 174, + 49, 236, 181, 205, 144, 131, 107, 113, 212, 194, 51, 103, 59, 21, 254, 228, 22, 100, 72, + 56, 115, 145, 130, 37, 159, 1, 86, 9, + ]; + let x0_b1_a0_bytes = [ + 110, 7, 174, 136, 163, 166, 185, 216, 13, 253, 185, 54, 98, 138, 172, 69, 174, 201, 224, + 173, 136, 39, 104, 115, 49, 121, 205, 32, 41, 60, 211, 20, 121, 127, 59, 21, 232, 21, 70, + 229, 85, 167, 158, 220, 206, 194, 61, 13, + ]; + let x0_b1_a1_bytes = [ + 202, 174, 161, 164, 127, 100, 139, 170, 157, 175, 150, 48, 67, 211, 86, 114, 98, 112, 118, + 3, 114, 72, 79, 21, 159, 94, 217, 155, 248, 141, 225, 169, 226, 250, 129, 40, 158, 219, + 156, 118, 90, 99, 244, 64, 66, 206, 74, 21, + ]; + let x0_b2_a0_bytes = [ + 215, 144, 182, 192, 19, 102, 21, 232, 158, 9, 31, 130, 212, 188, 238, 38, 170, 19, 229, 84, + 75, 24, 111, 142, 45, 145, 229, 48, 24, 184, 233, 158, 38, 62, 101, 186, 114, 91, 221, 55, + 65, 177, 108, 67, 158, 124, 155, 9, + ]; + let x0_b2_a1_bytes = [ + 64, 44, 116, 89, 206, 11, 228, 146, 252, 236, 146, 29, 185, 236, 100, 94, 122, 98, 78, 87, + 177, 244, 214, 2, 13, 132, 236, 195, 65, 161, 227, 70, 108, 189, 17, 229, 3, 52, 169, 45, + 226, 64, 174, 22, 254, 15, 191, 12, + ]; + let x1_a0_bytes = [ + 114, 106, 253, 79, 101, 99, 40, 6, 197, 30, 178, 73, 223, 122, 42, 247, 149, 236, 253, 200, + 209, 115, 97, 199, 100, 27, 124, 167, 186, 36, 238, 0, 217, 9, 223, 217, 47, 188, 242, 234, + 223, 225, 128, 69, 157, 221, 219, 12, + ]; + let x1_a1_bytes = [ + 124, 98, 167, 48, 13, 100, 22, 101, 244, 251, 76, 109, 36, 17, 221, 126, 147, 35, 171, 78, + 158, 4, 185, 1, 216, 28, 6, 58, 116, 108, 163, 8, 182, 253, 15, 51, 79, 123, 131, 108, 64, + 10, 160, 56, 244, 55, 72, 7, + ]; + let x2_b0_a0_bytes = [ + 153, 55, 58, 153, 36, 139, 91, 1, 157, 142, 175, 89, 153, 215, 36, 153, 112, 24, 223, 137, + 246, 136, 0, 233, 164, 171, 128, 99, 192, 200, 94, 71, 91, 98, 71, 192, 102, 137, 106, 60, + 158, 122, 239, 0, 147, 81, 179, 5, + ]; + let x2_b0_a1_bytes = [ + 173, 66, 149, 241, 216, 131, 213, 206, 107, 1, 169, 230, 249, 39, 185, 87, 1, 148, 238, + 174, 23, 178, 86, 73, 54, 92, 238, 174, 43, 198, 127, 81, 163, 84, 151, 138, 197, 159, 230, + 81, 0, 78, 116, 244, 147, 43, 211, 4, + ]; + let x2_b1_a0_bytes = [ + 62, 157, 10, 199, 254, 78, 13, 97, 44, 120, 224, 70, 91, 75, 226, 66, 53, 202, 111, 148, + 237, 182, 102, 239, 86, 42, 226, 26, 238, 35, 85, 252, 219, 84, 202, 237, 73, 130, 254, 21, + 215, 62, 251, 87, 198, 30, 87, 21, + ]; + let x2_b1_a1_bytes = [ + 118, 55, 226, 164, 64, 86, 177, 125, 35, 181, 228, 222, 21, 244, 209, 30, 48, 165, 18, 136, + 74, 152, 217, 237, 180, 21, 74, 136, 35, 36, 224, 236, 200, 90, 169, 148, 75, 14, 110, 250, + 159, 162, 149, 221, 95, 147, 151, 17, + ]; + let x2_b2_a0_bytes = [ + 178, 231, 158, 80, 57, 45, 61, 51, 192, 173, 128, 149, 51, 219, 187, 6, 27, 224, 109, 58, + 182, 90, 23, 59, 58, 241, 11, 39, 250, 215, 241, 128, 16, 22, 140, 42, 141, 122, 205, 52, + 39, 245, 102, 215, 23, 174, 254, 10, + ]; + let x2_b2_a1_bytes = [ + 56, 187, 148, 53, 25, 217, 226, 99, 85, 254, 164, 111, 88, 109, 86, 6, 250, 129, 217, 211, + 222, 9, 171, 190, 246, 148, 132, 90, 176, 253, 247, 67, 72, 186, 177, 65, 187, 205, 117, + 234, 105, 70, 3, 215, 194, 53, 158, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.b[0][i] = M31::from(x1_a0_bytes[i]); + assignment.b[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6MulByE2Circuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6MulBy01Circuit { + a: [[[Variable; 48]; 2]; 3], + c0: [[Variable; 48]; 2], + c1: [[Variable; 48]; 2], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulBy01Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c0_e2 = GE2 { + a0: new_internal_element(self.c0[0].to_vec(), 0), + a1: new_internal_element(self.c0[1].to_vec(), 0), + }; + + let c1_e2 = GE2 { + a0: new_internal_element(self.c1[0].to_vec(), 0), + a1: new_internal_element(self.c1[1].to_vec(), 0), + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_01(builder, &a_e6, &c0_e2, &c1_e2); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_01() { + // let compile_result = + // compile_generic(&E6MulBy01Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulBy01Circuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c0: [[M31::from(0); 48]; 2], + c1: [[M31::from(0); 48]; 2], + c: [[[M31::from(0); 48]; 2]; 3], + }; + let x0_b0_a0_bytes = [ + 239, 229, 161, 178, 64, 169, 64, 146, 202, 108, 226, 209, 171, 161, 210, 163, 187, 178, 82, + 117, 197, 147, 230, 123, 200, 118, 68, 116, 34, 4, 83, 5, 152, 248, 76, 174, 5, 112, 146, + 135, 108, 122, 197, 44, 5, 108, 105, 4, + ]; + let x0_b0_a1_bytes = [ + 216, 141, 84, 101, 248, 2, 198, 56, 82, 51, 71, 90, 78, 183, 64, 149, 118, 57, 60, 187, + 111, 237, 194, 199, 219, 87, 147, 173, 207, 209, 64, 111, 123, 230, 108, 254, 244, 133, 53, + 127, 124, 63, 113, 147, 77, 118, 183, 3, + ]; + let x0_b1_a0_bytes = [ + 252, 68, 138, 30, 240, 188, 31, 211, 225, 176, 125, 69, 159, 20, 155, 74, 109, 188, 182, + 240, 117, 158, 67, 126, 170, 59, 191, 249, 176, 86, 164, 133, 153, 181, 0, 208, 232, 168, + 81, 236, 62, 23, 145, 81, 4, 201, 133, 15, + ]; + let x0_b1_a1_bytes = [ + 69, 4, 32, 130, 215, 215, 132, 105, 38, 152, 198, 127, 228, 215, 56, 21, 211, 172, 97, 142, + 60, 71, 76, 251, 213, 10, 173, 20, 136, 142, 2, 77, 211, 134, 48, 29, 14, 55, 27, 130, 246, + 106, 239, 48, 238, 88, 93, 16, + ]; + let x0_b2_a0_bytes = [ + 14, 194, 113, 170, 251, 40, 206, 58, 33, 253, 225, 10, 146, 13, 43, 65, 62, 73, 217, 189, + 74, 205, 137, 20, 25, 102, 195, 121, 173, 201, 149, 110, 4, 161, 24, 190, 208, 112, 21, + 234, 125, 84, 183, 230, 250, 37, 20, 24, + ]; + let x0_b2_a1_bytes = [ + 107, 114, 82, 151, 175, 169, 28, 209, 16, 59, 150, 160, 0, 123, 71, 152, 251, 135, 94, 27, + 160, 226, 181, 125, 56, 52, 234, 172, 73, 206, 144, 100, 142, 162, 227, 202, 84, 30, 143, + 93, 245, 250, 146, 243, 7, 104, 210, 22, + ]; + let x1_a0_bytes = [ + 186, 151, 19, 68, 40, 192, 201, 108, 0, 91, 94, 25, 135, 234, 188, 37, 171, 13, 192, 227, + 215, 174, 77, 246, 206, 150, 192, 189, 188, 18, 52, 109, 174, 255, 45, 7, 112, 19, 158, + 246, 207, 176, 139, 230, 213, 125, 252, 17, + ]; + let x1_a1_bytes = [ + 21, 143, 182, 121, 149, 97, 79, 60, 204, 97, 32, 34, 238, 52, 114, 69, 145, 70, 181, 151, + 20, 254, 118, 41, 21, 21, 225, 217, 126, 14, 178, 141, 239, 124, 163, 129, 73, 88, 135, + 179, 215, 84, 62, 114, 42, 64, 68, 7, + ]; + let x2_a0_bytes = [ + 138, 88, 211, 80, 5, 54, 126, 91, 234, 136, 231, 41, 212, 67, 79, 189, 64, 69, 62, 2, 130, + 218, 241, 195, 164, 151, 141, 15, 73, 243, 223, 243, 185, 165, 89, 79, 139, 227, 17, 201, + 244, 9, 196, 252, 155, 229, 41, 14, + ]; + let x2_a1_bytes = [ + 188, 54, 82, 119, 88, 70, 53, 72, 210, 158, 255, 168, 36, 111, 243, 221, 38, 115, 86, 69, + 191, 147, 157, 51, 99, 204, 161, 227, 117, 163, 184, 79, 219, 60, 101, 125, 235, 215, 48, + 147, 224, 77, 251, 76, 225, 240, 1, 17, + ]; + let x3_b0_a0_bytes = [ + 40, 96, 6, 151, 173, 123, 226, 158, 228, 208, 229, 107, 250, 123, 77, 212, 186, 116, 42, + 150, 131, 126, 246, 122, 153, 71, 111, 206, 37, 27, 249, 210, 5, 214, 63, 13, 26, 76, 236, + 228, 15, 27, 44, 68, 86, 230, 77, 24, + ]; + let x3_b0_a1_bytes = [ + 140, 178, 226, 46, 250, 177, 38, 248, 99, 255, 15, 55, 233, 151, 29, 199, 102, 241, 52, 35, + 95, 113, 183, 199, 214, 107, 102, 112, 177, 214, 175, 168, 34, 130, 161, 190, 49, 245, 201, + 91, 45, 35, 145, 57, 43, 204, 222, 2, + ]; + let x3_b1_a0_bytes = [ + 246, 231, 192, 70, 80, 0, 214, 197, 196, 105, 124, 197, 34, 205, 213, 205, 9, 189, 175, + 232, 67, 175, 201, 10, 43, 23, 174, 144, 116, 110, 21, 175, 81, 126, 128, 21, 252, 69, 168, + 54, 68, 86, 146, 195, 55, 198, 122, 4, + ]; + let x3_b1_a1_bytes = [ + 249, 240, 86, 232, 156, 233, 242, 7, 101, 210, 128, 59, 74, 51, 114, 86, 181, 2, 22, 200, + 2, 61, 154, 240, 138, 7, 136, 232, 239, 90, 39, 109, 149, 12, 0, 53, 248, 48, 198, 163, 88, + 108, 25, 86, 41, 192, 50, 8, + ]; + let x3_b2_a0_bytes = [ + 202, 120, 182, 202, 118, 232, 150, 158, 129, 79, 84, 133, 125, 42, 4, 175, 202, 174, 44, + 152, 67, 60, 67, 69, 30, 143, 122, 56, 108, 238, 162, 89, 197, 243, 15, 19, 209, 209, 143, + 217, 164, 38, 189, 171, 222, 13, 210, 19, + ]; + let x3_b2_a1_bytes = [ + 154, 134, 254, 146, 102, 16, 154, 179, 160, 89, 167, 216, 187, 214, 197, 64, 58, 26, 12, + 159, 107, 92, 130, 18, 94, 56, 7, 68, 33, 81, 44, 186, 118, 68, 216, 94, 84, 87, 90, 231, + 93, 231, 209, 158, 109, 43, 242, 20, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c0[0][i] = M31::from(x1_a0_bytes[i]); + assignment.c0[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c1[0][i] = M31::from(x2_a0_bytes[i]); + assignment.c1[1][i] = M31::from(x2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6MulBy01Circuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6NegCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6NegCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.neg(builder, &a_e6); + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_neg() { + compile_generic(&E6NegCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6NegCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 116, 6, 234, 253, 168, 74, 65, 30, 170, 142, 158, 184, 33, 84, 176, 59, 39, 31, 68, 152, + 100, 233, 15, 176, 94, 80, 69, 58, 137, 167, 36, 189, 51, 230, 84, 91, 111, 236, 115, 231, + 37, 185, 220, 160, 17, 14, 196, 3, + ]; + let x0_b0_a1_bytes = [ + 74, 217, 113, 27, 10, 53, 174, 157, 74, 32, 126, 65, 73, 185, 191, 214, 75, 202, 59, 40, 1, + 229, 87, 54, 182, 214, 172, 205, 241, 238, 156, 6, 115, 105, 1, 134, 107, 190, 214, 227, + 195, 156, 125, 3, 27, 177, 68, 20, + ]; + let x0_b1_a0_bytes = [ + 156, 1, 181, 29, 159, 51, 200, 2, 179, 151, 250, 205, 64, 17, 207, 162, 7, 246, 108, 213, + 210, 159, 81, 251, 163, 6, 43, 23, 100, 250, 77, 164, 96, 61, 201, 255, 155, 157, 17, 183, + 138, 30, 232, 18, 210, 234, 119, 13, + ]; + let x0_b1_a1_bytes = [ + 67, 74, 29, 124, 15, 39, 125, 211, 85, 255, 163, 176, 37, 195, 144, 76, 67, 69, 116, 59, + 54, 163, 254, 137, 168, 252, 55, 64, 225, 163, 218, 46, 91, 93, 133, 23, 105, 178, 144, + 210, 71, 102, 22, 156, 220, 31, 126, 3, + ]; + let x0_b2_a0_bytes = [ + 165, 53, 235, 67, 200, 212, 135, 127, 103, 241, 184, 182, 61, 98, 13, 112, 24, 61, 180, 73, + 29, 81, 249, 63, 111, 128, 12, 220, 3, 213, 244, 214, 126, 148, 142, 13, 20, 84, 97, 163, + 244, 109, 32, 173, 58, 146, 143, 23, + ]; + let x0_b2_a1_bytes = [ + 139, 176, 170, 247, 65, 42, 233, 157, 160, 227, 93, 104, 151, 125, 167, 9, 117, 73, 194, 2, + 23, 230, 150, 90, 203, 142, 63, 12, 47, 48, 180, 119, 136, 117, 87, 9, 48, 16, 188, 215, + 25, 173, 239, 70, 235, 131, 89, 12, + ]; + let x3_b0_a0_bytes = [ + 55, 164, 21, 2, 87, 181, 189, 155, 85, 113, 181, 248, 220, 171, 251, 226, 252, 214, 108, + 94, 60, 233, 32, 183, 96, 194, 63, 185, 251, 163, 82, 167, 163, 198, 246, 231, 70, 187, + 167, 99, 116, 45, 163, 152, 216, 3, 61, 22, + ]; + let x3_b0_a1_bytes = [ + 97, 209, 141, 228, 245, 202, 80, 28, 181, 223, 213, 111, 181, 70, 236, 71, 216, 43, 117, + 206, 159, 237, 216, 48, 9, 60, 216, 37, 147, 92, 218, 93, 100, 67, 74, 189, 74, 233, 68, + 103, 214, 73, 2, 54, 207, 96, 188, 5, + ]; + let x3_b1_a0_bytes = [ + 15, 169, 74, 226, 96, 204, 54, 183, 76, 104, 89, 227, 189, 238, 220, 123, 28, 0, 68, 33, + 206, 50, 223, 107, 27, 12, 90, 220, 32, 81, 41, 192, 118, 111, 130, 67, 26, 10, 10, 148, + 15, 200, 151, 38, 24, 39, 137, 12, + ]; + let x3_b1_a1_bytes = [ + 104, 96, 226, 131, 240, 216, 129, 230, 169, 0, 176, 0, 217, 60, 27, 210, 224, 176, 60, 187, + 106, 47, 50, 221, 22, 22, 77, 179, 163, 167, 156, 53, 124, 79, 198, 43, 77, 245, 138, 120, + 82, 128, 105, 157, 13, 242, 130, 22, + ]; + let x3_b2_a0_bytes = [ + 6, 117, 20, 188, 55, 43, 119, 58, 152, 14, 155, 250, 192, 157, 158, 174, 11, 185, 252, 172, + 131, 129, 55, 39, 80, 146, 120, 23, 129, 118, 130, 141, 88, 24, 189, 53, 162, 83, 186, 167, + 165, 120, 95, 140, 175, 127, 113, 2, + ]; + let x3_b2_a1_bytes = [ + 32, 250, 84, 8, 190, 213, 21, 28, 95, 28, 246, 72, 103, 130, 4, 21, 175, 172, 238, 243, + 137, 236, 153, 12, 244, 131, 69, 231, 85, 27, 195, 236, 78, 55, 244, 57, 134, 151, 95, 115, + 128, 57, 144, 242, 254, 141, 167, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6NegCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E6InverseCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.inverse(builder, &a_e6); + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_inverse() { + compile_generic(&E6InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6InverseCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 42, 191, 107, 1, 61, 26, 173, 13, 160, 78, 61, 122, 92, 29, 163, 162, 133, 224, 146, 25, + 59, 158, 4, 106, 41, 66, 220, 84, 62, 148, 251, 247, 116, 66, 190, 14, 209, 79, 118, 179, + 163, 124, 142, 157, 70, 75, 135, 9, + ]; + let x0_b0_a1_bytes = [ + 211, 80, 115, 82, 164, 221, 106, 133, 199, 205, 208, 188, 168, 21, 57, 40, 179, 134, 122, + 83, 214, 125, 232, 227, 94, 208, 153, 53, 5, 91, 60, 107, 111, 192, 42, 241, 126, 4, 223, + 7, 202, 41, 151, 248, 42, 136, 202, 18, + ]; + let x0_b1_a0_bytes = [ + 230, 142, 241, 182, 172, 53, 243, 38, 51, 114, 207, 39, 193, 178, 94, 164, 237, 60, 49, + 201, 56, 151, 44, 35, 115, 180, 149, 238, 95, 234, 223, 68, 115, 48, 95, 57, 92, 8, 2, 55, + 89, 227, 203, 32, 236, 8, 37, 8, + ]; + let x0_b1_a1_bytes = [ + 175, 204, 91, 4, 54, 39, 255, 219, 210, 131, 129, 250, 20, 29, 26, 195, 225, 84, 161, 62, + 19, 4, 156, 203, 236, 158, 167, 164, 177, 156, 156, 191, 39, 168, 77, 57, 213, 134, 75, + 249, 148, 206, 186, 177, 237, 248, 25, 20, + ]; + let x0_b2_a0_bytes = [ + 72, 70, 59, 131, 175, 200, 39, 60, 247, 77, 55, 65, 105, 174, 197, 3, 147, 15, 56, 34, 225, + 101, 126, 71, 117, 222, 105, 147, 48, 91, 61, 157, 29, 199, 238, 20, 87, 18, 143, 164, 207, + 65, 151, 173, 84, 221, 69, 8, + ]; + let x0_b2_a1_bytes = [ + 124, 176, 9, 207, 196, 159, 159, 65, 67, 227, 130, 231, 59, 74, 160, 145, 206, 84, 167, + 199, 54, 98, 13, 14, 88, 232, 246, 1, 134, 251, 196, 191, 209, 208, 89, 19, 159, 83, 100, + 169, 65, 148, 60, 147, 220, 58, 39, 10, + ]; + let x3_b0_a0_bytes = [ + 241, 211, 96, 221, 135, 252, 51, 160, 240, 44, 177, 6, 233, 34, 43, 65, 225, 187, 89, 228, + 132, 88, 152, 212, 254, 70, 210, 244, 133, 61, 76, 202, 1, 214, 152, 159, 50, 108, 226, + 224, 77, 138, 58, 52, 196, 171, 248, 2, + ]; + let x3_b0_a1_bytes = [ + 102, 158, 6, 155, 253, 105, 81, 12, 177, 99, 91, 215, 140, 62, 35, 12, 235, 225, 229, 225, + 110, 51, 146, 31, 209, 37, 204, 124, 153, 134, 139, 92, 185, 55, 128, 182, 137, 140, 126, + 70, 213, 91, 217, 27, 245, 2, 135, 12, + ]; + let x3_b1_a0_bytes = [ + 80, 250, 232, 255, 129, 150, 236, 243, 241, 211, 26, 29, 138, 145, 205, 240, 56, 146, 126, + 65, 224, 117, 109, 179, 85, 61, 139, 201, 97, 176, 208, 180, 213, 192, 135, 20, 113, 168, + 90, 174, 215, 144, 185, 63, 18, 118, 199, 16, + ]; + let x3_b1_a1_bytes = [ + 79, 99, 136, 50, 88, 106, 124, 92, 158, 146, 150, 211, 235, 118, 143, 132, 238, 206, 182, + 239, 228, 54, 55, 88, 72, 112, 177, 56, 58, 73, 253, 9, 218, 106, 84, 202, 167, 194, 137, + 34, 248, 71, 70, 206, 63, 56, 27, 6, + ]; + let x3_b2_a0_bytes = [ + 214, 90, 220, 213, 91, 247, 245, 183, 117, 178, 27, 175, 136, 232, 144, 62, 52, 5, 23, 96, + 176, 81, 121, 179, 19, 91, 112, 174, 163, 162, 230, 68, 126, 148, 42, 157, 89, 88, 68, 113, + 249, 197, 123, 86, 231, 35, 229, 21, + ]; + let x3_b2_a1_bytes = [ + 138, 250, 218, 214, 205, 57, 171, 168, 67, 27, 229, 167, 87, 177, 26, 86, 82, 57, 100, 97, + 198, 239, 162, 172, 62, 30, 46, 232, 182, 101, 113, 253, 139, 213, 76, 44, 222, 32, 201, + 43, 244, 235, 1, 22, 14, 141, 123, 25, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6InverseCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs new file mode 100644 index 00000000..f2828701 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod e12; +pub mod e2; +pub mod e6; diff --git a/circuit-std-rs/tests/gnark/emulated/mod.rs b/circuit-std-rs/tests/gnark/emulated/mod.rs new file mode 100644 index 00000000..89f7a447 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/mod.rs @@ -0,0 +1,2 @@ +pub mod field_bls12381; +pub mod sw_bls12381; diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs new file mode 100644 index 00000000..f6fcad69 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs @@ -0,0 +1,142 @@ +use circuit_std_rs::gnark::{ + element::Element, + emulated::sw_bls12381::g1::{G1Affine, G1}, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(G1AddCircuit { + p: [[Variable; 48]; 2], + q: [[Variable; 48]; 2], + r: [[Variable; 48]; 2], +}); + +impl GenericDefine for G1AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + let p1_g1 = G1Affine { + x: Element::new( + self.p[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.p[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let p2_g1 = G1Affine { + x: Element::new( + self.q[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.q[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let r_g1 = G1Affine { + x: Element::new( + self.r[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.r[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let mut r = g1.add(builder, &p1_g1, &p2_g1); + for _ in 0..16 { + r = g1.add(builder, &r, &p2_g1); + } + g1.curve_f.assert_isequal(builder, &r.x, &r_g1.x); + g1.curve_f.assert_isequal(builder, &r.y, &r_g1.y); + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_g1_add() { + compile_generic(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = G1AddCircuit:: { + p: [[M31::from(0); 48]; 2], + q: [[M31::from(0); 48]; 2], + r: [[M31::from(0); 48]; 2], + }; + let p1_x_bytes = [ + 169, 204, 143, 202, 195, 182, 32, 187, 150, 46, 27, 88, 137, 82, 209, 11, 255, 228, 147, + 72, 218, 149, 56, 139, 243, 28, 49, 146, 210, 5, 238, 232, 111, 204, 78, 170, 83, 191, 222, + 173, 137, 165, 150, 240, 62, 27, 213, 8, + ]; + let p1_y_bytes = [ + 85, 56, 238, 125, 65, 131, 108, 201, 186, 2, 96, 151, 226, 80, 22, 2, 111, 141, 203, 67, + 50, 147, 209, 102, 238, 82, 12, 96, 172, 239, 2, 177, 184, 146, 208, 150, 63, 214, 239, + 198, 101, 74, 169, 226, 148, 53, 104, 1, + ]; + let p2_x_bytes = [ + 108, 4, 52, 16, 255, 115, 116, 198, 234, 60, 202, 181, 169, 240, 221, 33, 38, 178, 114, + 195, 169, 16, 147, 33, 62, 116, 10, 191, 25, 163, 79, 192, 140, 43, 109, 235, 157, 42, 15, + 48, 115, 213, 48, 51, 19, 165, 178, 17, + ]; + let p2_y_bytes = [ + 130, 146, 65, 1, 211, 117, 217, 145, 69, 140, 76, 106, 43, 160, 192, 247, 96, 225, 2, 72, + 219, 238, 254, 202, 9, 210, 253, 111, 73, 49, 26, 145, 68, 161, 64, 101, 238, 0, 236, 128, + 164, 92, 95, 30, 143, 178, 6, 20, + ]; + let res_x_bytes = [ + 148, 92, 212, 64, 35, 246, 218, 14, 150, 169, 177, 191, 61, 6, 4, 120, 60, 253, 36, 139, + 95, 95, 14, 122, 89, 3, 62, 198, 100, 50, 114, 221, 144, 187, 29, 15, 203, 89, 220, 29, + 120, 25, 153, 169, 184, 184, 133, 16, + ]; + let res_y_bytes = [ + 41, 226, 254, 238, 50, 145, 74, 128, 160, 125, 237, 161, 93, 66, 241, 104, 218, 230, 154, + 134, 24, 204, 225, 220, 175, 115, 243, 57, 238, 157, 161, 175, 213, 34, 145, 106, 226, 230, + 19, 110, 196, 196, 229, 104, 152, 64, 12, 6, + ]; + + for i in 0..48 { + assignment.p[0][i] = M31::from(p1_x_bytes[i]); + assignment.p[1][i] = M31::from(p1_y_bytes[i]); + assignment.q[0][i] = M31::from(p2_x_bytes[i]); + assignment.q[1][i] = M31::from(p2_y_bytes[i]); + assignment.r[0][i] = M31::from(res_x_bytes[i]); + assignment.r[1][i] = M31::from(res_y_bytes[i]); + } + + debug_eval(&G1AddCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs new file mode 100644 index 00000000..fbdb3da2 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs @@ -0,0 +1,2 @@ +pub mod g1; +pub mod pairing; diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs new file mode 100644 index 00000000..51192af2 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs @@ -0,0 +1,252 @@ +use circuit_std_rs::gnark::{ + element::Element, + emulated::{ + field_bls12381::e2::GE2, + sw_bls12381::{g1::*, g2::*, pairing::Pairing}, + }, + hints::register_hint, +}; +use expander_compiler::{ + declare_circuit, + frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, +}; + +declare_circuit!(PairingCheckGKRCircuit { + in1_g1: [[Variable; 48]; 2], + in2_g1: [[Variable; 48]; 2], + in1_g2: [[[Variable; 48]; 2]; 2], + in2_g2: [[[Variable; 48]; 2]; 2], +}); + +impl GenericDefine for PairingCheckGKRCircuit { + fn define>(&self, builder: &mut Builder) { + let mut pairing = Pairing::new(builder); + let p1_g1 = G1Affine { + x: Element::new( + self.in1_g1[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.in1_g1[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let p2_g1 = G1Affine { + x: Element::new( + self.in2_g1[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.in2_g1[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let q1_g2 = G2AffP { + x: GE2 { + a0: Element::new( + self.in1_g2[0][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in1_g2[0][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + y: GE2 { + a0: Element::new( + self.in1_g2[1][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in1_g2[1][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + }; + let q2_g2 = G2AffP { + x: GE2 { + a0: Element::new( + self.in2_g2[0][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in2_g2[0][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + y: GE2 { + a0: Element::new( + self.in2_g2[1][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in2_g2[1][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + }; + pairing + .pairing_check( + builder, + &[p1_g1, p2_g1], + &mut [ + G2Affine { + p: q1_g2, + lines: LineEvaluations::default(), + }, + G2Affine { + p: q2_g2, + lines: LineEvaluations::default(), + }, + ], + ) + .unwrap(); + pairing.ext12.ext6.ext2.curve_f.check_mul(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_pairing_check_gkr() { + // let compile_result = + // compile_generic(&PairingCheckGKRCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = PairingCheckGKRCircuit:: { + in1_g1: [[M31::from(0); 48]; 2], + in2_g1: [[M31::from(0); 48]; 2], + in1_g2: [[[M31::from(0); 48]; 2]; 2], + in2_g2: [[[M31::from(0); 48]; 2]; 2], + }; + let p1_x_bytes = [ + 138, 209, 41, 52, 20, 222, 185, 9, 48, 234, 53, 109, 218, 26, 76, 112, 204, 195, 135, 184, + 95, 253, 141, 179, 243, 220, 94, 195, 151, 34, 112, 210, 63, 186, 25, 221, 129, 128, 76, + 209, 101, 191, 44, 36, 248, 25, 127, 3, + ]; + let p1_y_bytes = [ + 97, 193, 54, 196, 208, 241, 229, 252, 144, 121, 89, 115, 226, 242, 251, 60, 142, 182, 216, + 242, 212, 30, 189, 82, 97, 228, 230, 80, 38, 19, 77, 187, 242, 96, 65, 136, 115, 75, 173, + 136, 35, 202, 199, 3, 37, 33, 182, 19, + ]; + let p2_x_bytes = [ + 53, 43, 44, 191, 248, 216, 253, 96, 84, 253, 43, 36, 151, 202, 77, 190, 19, 71, 28, 215, + 161, 72, 57, 211, 182, 58, 152, 199, 107, 235, 238, 63, 160, 97, 190, 43, 89, 195, 111, + 179, 72, 18, 109, 141, 133, 74, 215, 16, + ]; + let p2_y_bytes = [ + 96, 0, 147, 41, 253, 168, 205, 45, 124, 150, 80, 188, 171, 228, 217, 34, 233, 192, 87, 38, + 176, 98, 88, 196, 41, 115, 40, 174, 52, 234, 97, 53, 209, 179, 91, 66, 107, 130, 187, 171, + 10, 254, 6, 227, 50, 212, 34, 8, + ]; + let q1_x0_bytes = [ + 115, 71, 82, 0, 253, 98, 21, 231, 188, 204, 204, 250, 44, 169, 184, 249, 132, 60, 132, 14, + 34, 48, 165, 84, 111, 109, 143, 182, 32, 72, 227, 210, 133, 144, 154, 196, 16, 169, 138, + 79, 19, 122, 34, 156, 176, 236, 114, 22, + ]; + let q1_x1_bytes = [ + 182, 57, 221, 84, 50, 87, 48, 115, 6, 98, 38, 176, 152, 25, 126, 43, 201, 61, 87, 42, 225, + 138, 200, 170, 0, 20, 174, 117, 112, 157, 233, 97, 0, 149, 210, 18, 224, 229, 157, 26, 197, + 93, 245, 96, 227, 157, 237, 15, + ]; + let q1_y0_bytes = [ + 185, 67, 44, 184, 194, 122, 245, 73, 123, 160, 144, 28, 83, 227, 9, 222, 52, 33, 74, 97, + 66, 113, 234, 143, 125, 244, 115, 58, 79, 29, 83, 208, 130, 83, 146, 30, 95, 202, 3, 189, + 0, 6, 81, 73, 107, 141, 234, 1, + ]; + let q1_y1_bytes = [ + 113, 182, 199, 78, 243, 62, 126, 145, 147, 111, 153, 151, 219, 69, 54, 127, 72, 82, 59, + 169, 219, 65, 228, 8, 193, 143, 67, 158, 12, 45, 225, 109, 220, 217, 133, 185, 75, 245, 82, + 200, 137, 178, 165, 90, 190, 232, 244, 21, + ]; + let q2_x0_bytes = [ + 48, 100, 73, 236, 161, 161, 88, 235, 92, 188, 236, 139, 70, 238, 43, 160, 189, 118, 66, + 116, 44, 222, 23, 195, 67, 252, 105, 112, 240, 119, 247, 53, 3, 24, 156, 3, 178, 117, 41, + 16, 120, 114, 244, 103, 65, 157, 255, 21, + ]; + let q2_x1_bytes = [ + 87, 198, 239, 80, 28, 107, 195, 211, 220, 50, 148, 176, 2, 30, 65, 17, 206, 180, 103, 123, + 161, 64, 40, 77, 84, 98, 25, 164, 111, 180, 209, 62, 23, 78, 4, 174, 123, 52, 30, 19, 149, + 4, 6, 56, 6, 173, 138, 12, + ]; + let q2_y0_bytes = [ + 178, 164, 255, 33, 62, 219, 245, 30, 146, 252, 242, 196, 23, 5, 90, 103, 75, 9, 67, 186, + 155, 40, 106, 209, 158, 161, 142, 60, 109, 58, 29, 180, 3, 126, 95, 225, 244, 243, 36, 82, + 32, 223, 19, 39, 202, 170, 158, 12, + ]; + let q2_y1_bytes = [ + 47, 93, 130, 172, 91, 197, 69, 2, 220, 41, 78, 230, 47, 199, 202, 197, 177, 54, 53, 90, + 233, 76, 186, 248, 212, 121, 120, 208, 231, 195, 87, 150, 233, 33, 103, 94, 11, 15, 108, + 247, 78, 10, 223, 139, 186, 5, 53, 8, + ]; + + for i in 0..48 { + assignment.in1_g1[0][i] = M31::from(p1_x_bytes[i]); + assignment.in1_g1[1][i] = M31::from(p1_y_bytes[i]); + assignment.in2_g1[0][i] = M31::from(p2_x_bytes[i]); + assignment.in2_g1[1][i] = M31::from(p2_y_bytes[i]); + assignment.in1_g2[0][0][i] = M31::from(q1_x0_bytes[i]); + assignment.in1_g2[0][1][i] = M31::from(q1_x1_bytes[i]); + assignment.in1_g2[1][0][i] = M31::from(q1_y0_bytes[i]); + assignment.in1_g2[1][1][i] = M31::from(q1_y1_bytes[i]); + assignment.in2_g2[0][0][i] = M31::from(q2_x0_bytes[i]); + assignment.in2_g2[0][1][i] = M31::from(q2_x1_bytes[i]); + assignment.in2_g2[1][0][i] = M31::from(q2_y0_bytes[i]); + assignment.in2_g2[1][1][i] = M31::from(q2_y1_bytes[i]); + } + + debug_eval( + &PairingCheckGKRCircuit::default(), + &assignment, + hint_registry, + ); +} diff --git a/circuit-std-rs/tests/gnark/mod.rs b/circuit-std-rs/tests/gnark/mod.rs new file mode 100644 index 00000000..e871bbde --- /dev/null +++ b/circuit-std-rs/tests/gnark/mod.rs @@ -0,0 +1,7 @@ +pub mod element; +// pub mod emparam; +// pub mod emulated; +// pub mod field; +// pub mod hints; +// pub mod limbs; +// pub mod utils; diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 1f2a44ca..14522286 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -1,6 +1,9 @@ mod common; -use circuit_std_rs::{LogUpCircuit, LogUpParams}; +use circuit_std_rs::{ + logup::{query_count_hint, rangeproof_hint, LogUpRangeProofTable}, + LogUpCircuit, LogUpParams, +}; use expander_compiler::frontend::*; #[test] @@ -16,3 +19,42 @@ fn logup_test() { common::circuit_test_helper::(&logup_params); common::circuit_test_helper::(&logup_params); } + +declare_circuit!(LogUpRangeproofCircuit { test: Variable }); +impl GenericDefine for LogUpRangeproofCircuit { + fn define>(&self, builder: &mut Builder) { + let mut table = LogUpRangeProofTable::new(8); + table.initial(builder); + for i in 1..12 { + for j in (1 << (i - 1))..(1 << i) { + let key = builder.constant(j); + if i > 8 { + table.rangeproof(builder, key, i); + } else { + table.rangeproof_onechunk(builder, key, i); + } + } + } + table.final_check(builder); + } +} + +#[test] +fn rangeproof_logup_test() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); + //compile and test + let compile_result = compile_generic( + &LogUpRangeproofCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let assignment = LogUpRangeproofCircuit { test: M31::from(0) }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/tests/poseidon_m31.rs b/circuit-std-rs/tests/poseidon_m31.rs new file mode 100644 index 00000000..0faa5ae4 --- /dev/null +++ b/circuit-std-rs/tests/poseidon_m31.rs @@ -0,0 +1,108 @@ +use circuit_std_rs::poseidon_m31::*; +use expander_compiler::frontend::*; + +declare_circuit!(PoseidonSpongeLen8Circuit { + inputs: [Variable; 8], + outputs: [Variable; 16] +}); + +impl Define for PoseidonSpongeLen8Circuit { + fn define(&self, builder: &mut API) { + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let res = params.hash_to_state(builder, &self.inputs); + (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); + } +} + +#[test] +// NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 8 +fn test_poseidon_m31x16_hash_to_state_input_len8() { + let compile_result = compile(&PoseidonSpongeLen8Circuit::default()).unwrap(); + + let assignment = PoseidonSpongeLen8Circuit:: { + inputs: [M31::from(114514); 8], + outputs: [ + M31 { v: 1021105124 }, + M31 { v: 1342990709 }, + M31 { v: 1593716396 }, + M31 { v: 2100280498 }, + M31 { v: 330652568 }, + M31 { v: 1371365483 }, + M31 { v: 586650367 }, + M31 { v: 345482939 }, + M31 { v: 849034538 }, + M31 { v: 175601510 }, + M31 { v: 1454280121 }, + M31 { v: 1362077584 }, + M31 { v: 528171622 }, + M31 { v: 187534772 }, + M31 { v: 436020341 }, + M31 { v: 1441052621 }, + ], + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(PoseidonSpongeLen16Circuit { + inputs: [Variable; 16], + outputs: [Variable; 16] +}); + +impl Define for PoseidonSpongeLen16Circuit { + fn define(&self, builder: &mut API) { + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let res = params.hash_to_state(builder, &self.inputs); + (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); + } +} + +#[test] +// NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 16 +fn test_poseidon_m31x16_hash_to_state_input_len16() { + let compile_result = compile(&PoseidonSpongeLen16Circuit::default()).unwrap(); + + let assignment = PoseidonSpongeLen16Circuit:: { + inputs: [M31::from(114514); 16], + outputs: [ + M31 { v: 1510043913 }, + M31 { v: 1840611937 }, + M31 { v: 45881205 }, + M31 { v: 1134797377 }, + M31 { v: 803058407 }, + M31 { v: 1772167459 }, + M31 { v: 846553905 }, + M31 { v: 2143336151 }, + M31 { v: 300871060 }, + M31 { v: 545838827 }, + M31 { v: 1603101164 }, + M31 { v: 396293243 }, + M31 { v: 502075988 }, + M31 { v: 2067011878 }, + M31 { v: 402134378 }, + M31 { v: 535675968 }, + ], + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/tests/sha256_debug_utils.rs b/circuit-std-rs/tests/sha256_debug_utils.rs new file mode 100644 index 00000000..7f0fb30c --- /dev/null +++ b/circuit-std-rs/tests/sha256_debug_utils.rs @@ -0,0 +1,281 @@ +// the compression function of sha256, used to debug only, credit: https://crates.io/crates/sha2 + +#![allow(clippy::many_single_char_names)] +pub const BLOCK_LEN: usize = 16; +use core::convert::TryInto; + +#[inline(always)] +fn shl(v: [u32; 4], o: u32) -> [u32; 4] { + [v[0] >> o, v[1] >> o, v[2] >> o, v[3] >> o] +} + +#[inline(always)] +fn shr(v: [u32; 4], o: u32) -> [u32; 4] { + [v[0] << o, v[1] << o, v[2] << o, v[3] << o] +} + +#[inline(always)] +fn or(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [a[0] | b[0], a[1] | b[1], a[2] | b[2], a[3] | b[3]] +} + +#[inline(always)] +fn xor(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [a[0] ^ b[0], a[1] ^ b[1], a[2] ^ b[2], a[3] ^ b[3]] +} + +#[inline(always)] +fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [ + a[0].wrapping_add(b[0]), + a[1].wrapping_add(b[1]), + a[2].wrapping_add(b[2]), + a[3].wrapping_add(b[3]), + ] +} + +fn sha256load(v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + [v3[3], v2[0], v2[1], v2[2]] +} + +fn sha256swap(v0: [u32; 4]) -> [u32; 4] { + [v0[2], v0[3], v0[0], v0[1]] +} + +fn sha256msg1(v0: [u32; 4], v1: [u32; 4]) -> [u32; 4] { + // sigma 0 on vectors + #[inline] + fn sigma0x4(x: [u32; 4]) -> [u32; 4] { + let t1 = or(shl(x, 7), shr(x, 25)); + let t2 = or(shl(x, 18), shr(x, 14)); + let t3 = shl(x, 3); + xor(xor(t1, t2), t3) + } + + add(v0, sigma0x4(sha256load(v0, v1))) +} + +fn sha256msg2(v4: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + macro_rules! sigma1 { + ($a:expr) => { + $a.rotate_right(17) ^ $a.rotate_right(19) ^ ($a >> 10) + }; + } + + let [x3, x2, x1, x0] = v4; + let [w15, w14, _, _] = v3; + + let w16 = x0.wrapping_add(sigma1!(w14)); + let w17 = x1.wrapping_add(sigma1!(w15)); + let w18 = x2.wrapping_add(sigma1!(w16)); + let w19 = x3.wrapping_add(sigma1!(w17)); + + [w19, w18, w17, w16] +} + +fn sha256_digest_round_x2(cdgh: [u32; 4], abef: [u32; 4], wk: [u32; 4]) -> [u32; 4] { + macro_rules! big_sigma0 { + ($a:expr) => { + ($a.rotate_right(2) ^ $a.rotate_right(13) ^ $a.rotate_right(22)) + }; + } + macro_rules! big_sigma1 { + ($a:expr) => { + ($a.rotate_right(6) ^ $a.rotate_right(11) ^ $a.rotate_right(25)) + }; + } + macro_rules! bool3ary_202 { + ($a:expr, $b:expr, $c:expr) => { + $c ^ ($a & ($b ^ $c)) + }; + } // Choose, MD5F, SHA1C + macro_rules! bool3ary_232 { + ($a:expr, $b:expr, $c:expr) => { + ($a & $b) ^ ($a & $c) ^ ($b & $c) + }; + } // Majority, SHA1M + + let [_, _, wk1, wk0] = wk; + let [a0, b0, e0, f0] = abef; + let [c0, d0, g0, h0] = cdgh; + + // a round + let x0 = big_sigma1!(e0) + .wrapping_add(bool3ary_202!(e0, f0, g0)) + .wrapping_add(wk0) + .wrapping_add(h0); + let y0 = big_sigma0!(a0).wrapping_add(bool3ary_232!(a0, b0, c0)); + let (a1, b1, c1, d1, e1, f1, g1, h1) = ( + x0.wrapping_add(y0), + a0, + b0, + c0, + x0.wrapping_add(d0), + e0, + f0, + g0, + ); + + // a round + let x1 = big_sigma1!(e1) + .wrapping_add(bool3ary_202!(e1, f1, g1)) + .wrapping_add(wk1) + .wrapping_add(h1); + let y1 = big_sigma0!(a1).wrapping_add(bool3ary_232!(a1, b1, c1)); + let (a2, b2, _, _, e2, f2, _, _) = ( + x1.wrapping_add(y1), + a1, + b1, + c1, + x1.wrapping_add(d1), + e1, + f1, + g1, + ); + + [a2, b2, e2, f2] +} + +fn schedule(v0: [u32; 4], v1: [u32; 4], v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + let t1 = sha256msg1(v0, v1); + let t2 = sha256load(v2, v3); + let t3 = add(t1, t2); + sha256msg2(t3, v3) +} + +/// Constants necessary for SHA-256 family of digests. +pub const K32: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +/// Constants necessary for SHA-256 family of digests. +pub const K32X4: [[u32; 4]; 16] = [ + [K32[3], K32[2], K32[1], K32[0]], + [K32[7], K32[6], K32[5], K32[4]], + [K32[11], K32[10], K32[9], K32[8]], + [K32[15], K32[14], K32[13], K32[12]], + [K32[19], K32[18], K32[17], K32[16]], + [K32[23], K32[22], K32[21], K32[20]], + [K32[27], K32[26], K32[25], K32[24]], + [K32[31], K32[30], K32[29], K32[28]], + [K32[35], K32[34], K32[33], K32[32]], + [K32[39], K32[38], K32[37], K32[36]], + [K32[43], K32[42], K32[41], K32[40]], + [K32[47], K32[46], K32[45], K32[44]], + [K32[51], K32[50], K32[49], K32[48]], + [K32[55], K32[54], K32[53], K32[52]], + [K32[59], K32[58], K32[57], K32[56]], + [K32[63], K32[62], K32[61], K32[60]], +]; + +macro_rules! rounds4 { + ($abef:ident, $cdgh:ident, $rest:expr, $i:expr) => {{ + let t1 = add($rest, K32X4[$i]); + $cdgh = sha256_digest_round_x2($cdgh, $abef, t1); + let t2 = sha256swap(t1); + $abef = sha256_digest_round_x2($abef, $cdgh, t2); + }}; +} + +macro_rules! schedule_rounds4 { + ( + $abef:ident, $cdgh:ident, + $w0:expr, $w1:expr, $w2:expr, $w3:expr, $w4:expr, + $i: expr + ) => {{ + $w4 = schedule($w0, $w1, $w2, $w3); + rounds4!($abef, $cdgh, $w4, $i); + }}; +} + +#[allow(dead_code)] +fn print_state(abef: &[u32; 4], cdgh: &[u32; 4]) { + for i in 0..2 { + print!("{} ", (abef[i] >> 1) & 1); + } + + for i in 0..2 { + print!("{} ", (cdgh[i] >> 1) & 1); + } + + for i in 2..4 { + print!("{} ", (abef[i] >> 1) & 1); + } + + for i in 2..4 { + print!("{} ", (cdgh[i] >> 1) & 1); + } + + println!(); +} + +/// Process a block with the SHA-256 algorithm. +fn sha256_digest_block_u32(state: &mut [u32; 8], block: &[u32; 16]) { + let mut abef = [state[0], state[1], state[4], state[5]]; + let mut cdgh = [state[2], state[3], state[6], state[7]]; + + // Rounds 0..64 + let mut w0 = [block[3], block[2], block[1], block[0]]; + let mut w1 = [block[7], block[6], block[5], block[4]]; + let mut w2 = [block[11], block[10], block[9], block[8]]; + let mut w3 = [block[15], block[14], block[13], block[12]]; + let mut w4; + + // [w3, w2, w1, w0] would be the total big-endian interpretation of the block + + rounds4!(abef, cdgh, w0, 0); + rounds4!(abef, cdgh, w1, 1); + rounds4!(abef, cdgh, w2, 2); + rounds4!(abef, cdgh, w3, 3); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 4); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 5); + schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 6); + schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 7); + schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 8); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 9); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 10); + schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 11); + schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 12); + schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 13); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 14); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 15); + + let [a, b, e, f] = abef; + let [c, d, g, h] = cdgh; + + state[0] = state[0].wrapping_add(a); + state[1] = state[1].wrapping_add(b); + state[2] = state[2].wrapping_add(c); + state[3] = state[3].wrapping_add(d); + state[4] = state[4].wrapping_add(e); + state[5] = state[5].wrapping_add(f); + state[6] = state[6].wrapping_add(g); + state[7] = state[7].wrapping_add(h); +} + +pub const H256_256: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + let mut block_u32 = [0u32; BLOCK_LEN]; + + // since LLVM can't properly use aliasing yet it will make + // unnecessary state stores without this copy + let mut state_cpy = *state; + for block in blocks { + // block is interpreted as u32 in big endian for every 4 bytes + for (o, chunk) in block_u32.iter_mut().zip(block.chunks_exact(4)) { + *o = u32::from_be_bytes(chunk.try_into().unwrap()); + } + sha256_digest_block_u32(&mut state_cpy, &block_u32); + } + *state = state_cpy; +} diff --git a/circuit-std-rs/tests/sha256_gf2.rs b/circuit-std-rs/tests/sha256_gf2.rs new file mode 100644 index 00000000..4df1b4cf --- /dev/null +++ b/circuit-std-rs/tests/sha256_gf2.rs @@ -0,0 +1,137 @@ +use circuit_std_rs::sha256::{gf2::SHA256GF2, gf2_utils::u32_to_bit}; +use expander_compiler::frontend::*; +#[allow(unused_imports)] +use extra::debug_eval; +use rand::RngCore; +use sha2::{Digest, Sha256}; + +mod sha256_debug_utils; +use sha256_debug_utils::{compress, H256_256 as SHA256_INIT_STATE}; + +const INPUT_LEN: usize = 1024; // input size in bits, must be a multiple of 8 +const OUTPUT_LEN: usize = 256; // FIXED 256 + +declare_circuit!(SHA256CircuitCompressionOnly { + input: [Variable; 512], + output: [Variable; 256], +}); + +impl GenericDefine for SHA256CircuitCompressionOnly { + fn define>(&self, api: &mut Builder) { + let hasher = SHA256GF2::new(); + let mut state = SHA256_INIT_STATE + .iter() + .map(|x| u32_to_bit(api, *x)) + .collect::>() + .try_into() + .unwrap(); + hasher.sha256_compress(api, &mut state, &self.input); + let output = state.iter().flatten().cloned().collect::>(); + for i in 0..256 { + api.assert_is_equal(output[i], self.output[i]); + } + } +} + +#[test] +fn test_sha256_compression_gf2() { + // let compile_result = compile_generic( + // &SHA256CircuitCompressionOnly::default(), + // CompileOptions::default(), + // ) + // .unwrap(); + + let compile_result = compile_generic_cross_layer( + &SHA256CircuitCompressionOnly::default(), + CompileOptions::default(), + ) + .unwrap(); + + let mut rng = rand::thread_rng(); + let n_tests = 5; + for _ in 0..n_tests { + let data = [rng.next_u32() as u8; 512 / 8]; + let mut state = SHA256_INIT_STATE; + compress(&mut state, &[data.try_into().unwrap()]); + let output = state + .iter() + .flat_map(|v| v.to_be_bytes()) + .collect::>(); + + let mut assignment = SHA256CircuitCompressionOnly::default(); + + for i in 0..64 { + for j in 0..8 { + assignment.input[i * 8 + j] = ((data[i] >> (7 - j)) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.output[i * 8 + j] = ((output[i] >> (7 - j)) as u32 & 1).into(); + } + } + + // debug_eval::( + // &SHA256CircuitCompressionOnly::default(), + // &assignment, + // EmptyHintCaller::new(), + // ); + + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} + +declare_circuit!(SHA256Circuit { + input: [Variable; INPUT_LEN], + output: [Variable; OUTPUT_LEN], +}); + +impl GenericDefine for SHA256Circuit { + fn define>(&self, api: &mut Builder) { + let mut hasher = SHA256GF2::new(); + hasher.update(&self.input); + let output = hasher.finalize(api); + (0..OUTPUT_LEN).for_each(|i| api.assert_is_equal(output[i], self.output[i])); + } +} + +#[test] +fn test_sha256_gf2() { + assert!(INPUT_LEN % 8 == 0); + // let compile_result = + // compile_generic(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + + let compile_result = + compile_generic_cross_layer(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + + let n_tests = 5; + let mut rng = rand::thread_rng(); + for _ in 0..n_tests { + let data = [rng.next_u32() as u8; INPUT_LEN / 8]; + let mut hash = Sha256::new(); + hash.update(data); + let output = hash.finalize(); + let mut assignment = SHA256Circuit::default(); + for i in 0..INPUT_LEN / 8 { + for j in 0..8 { + assignment.input[i * 8 + j] = (((data[i] >> (7 - j)) & 1) as u32).into(); + } + } + for i in 0..OUTPUT_LEN / 8 { + for j in 0..8 { + assignment.output[i * 8 + j] = (((output[i] >> (7 - j) as u32) & 1) as u32).into(); + } + } + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} diff --git a/circuit-std-rs/tests/sha256_m31.rs b/circuit-std-rs/tests/sha256_m31.rs new file mode 100644 index 00000000..028f8e9e --- /dev/null +++ b/circuit-std-rs/tests/sha256_m31.rs @@ -0,0 +1,76 @@ +use circuit_std_rs::{sha256::m31::sha256_37bytes, sha256::m31_utils::to_binary_hint}; +use expander_compiler::frontend::*; +use extra::*; +use sha2::{Digest, Sha256}; + +declare_circuit!(SHA25637BYTESCircuit { + input: [Variable; 37], + output: [Variable; 32], +}); + +pub fn check_sha256>( + builder: &mut B, + origin_data: &Vec, +) -> Vec { + let output = origin_data[37..].to_vec(); + let result = sha256_37bytes(builder, &origin_data[..37]); + for i in 0..32 { + builder.assert_is_equal(result[i], output[i]); + } + result +} + +impl GenericDefine for SHA25637BYTESCircuit { + fn define>(&self, builder: &mut Builder) { + for _ in 0..8 { + let mut data = self.input.to_vec(); + data.append(&mut self.output.to_vec()); + builder.memorized_simple_call(check_sha256, &data); + } + } +} + +#[test] +fn test_sha256_37bytes() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + let compile_result = + compile_generic(&SHA25637BYTESCircuit::default(), CompileOptions::default()).unwrap(); + for i in 0..1 { + let data = [i; 37]; + let mut hash = Sha256::new(); + hash.update(data); + let output = hash.finalize(); + let mut assignment = SHA25637BYTESCircuit::default(); + for i in 0..37 { + assignment.input[i] = M31::from(data[i] as u32); + } + for i in 0..32 { + assignment.output[i] = M31::from(output[i] as u32); + } + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} + +#[test] +fn debug_sha256_37bytes() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + let data = [255; 37]; + let mut hash = Sha256::new(); + hash.update(data); + let output = hash.finalize(); + let mut assignment = SHA25637BYTESCircuit::default(); + for i in 0..37 { + assignment.input[i] = M31::from(data[i] as u32); + } + for i in 0..32 { + assignment.output[i] = M31::from(output[i] as u32); + } + debug_eval(&SHA25637BYTESCircuit::default(), &assignment, hint_registry); +} diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 7dd0d243..b26e6549 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -53,6 +53,26 @@ func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f // returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true. func (builder *builder) add(vars []int, sub bool) frontend.Variable { + // check if all variables are constants + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + if sub { + sum = builder.field.Sub(sum, v) + } else { + sum = builder.field.Add(sum, v) + } + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } + coef := make([]constraint.Element, len(vars)) coef[0] = builder.tOne if sub { @@ -75,6 +95,9 @@ func (builder *builder) add(vars []int, sub bool) frontend.Variable { // Neg returns the negation of the given variable. func (builder *builder) Neg(i frontend.Variable) frontend.Variable { v := builder.toVariableId(i) + if c, ok := builder.constantValue(v); ok { + return builder.toVariable(builder.field.Neg(c)) + } coef := []constraint.Element{builder.field.Neg(builder.tOne)} builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, @@ -87,6 +110,20 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { // Mul computes the product of the given variables. func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + sum = builder.field.Mul(sum, v) + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Mul, Inputs: vars, @@ -99,6 +136,18 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + if c1.IsZero() { + return builder.toVariable(constraint.Element{}) + } + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -113,6 +162,15 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -160,6 +218,17 @@ func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + t := builder.field.Sub(c1, c2) + if t.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -174,6 +243,16 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() && c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -188,6 +267,16 @@ func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() || c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -207,7 +296,15 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // ensures that cond is boolean builder.AssertIsBoolean(cond) - v := builder.Sub(i1, i2) // no constraint is recorded + cst, ok := builder.constantValue(builder.toVariableId(cond)) + if ok { + if cst.IsZero() { + return i2 + } + return i1 + } + + v := builder.Sub(i1, i2) w := builder.Mul(cond, v) return builder.Add(w, i2) } @@ -246,6 +343,12 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if the given variable is zero, otherwise returns 0. func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { a := builder.toVariableId(i1) + if c, ok := builder.constantValue(a); ok { + if c.IsZero() { + return builder.toVariable(builder.tOne) + } + return builder.toVariable(constraint.Element{}) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.IsZero, X: a, diff --git a/ecgo/builder/api_assertions.go b/ecgo/builder/api_assertions.go index c0102595..ae8d9370 100644 --- a/ecgo/builder/api_assertions.go +++ b/ecgo/builder/api_assertions.go @@ -13,6 +13,13 @@ import ( // AssertIsEqual adds an assertion that i1 is equal to i2. func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if !v.IsZero() { + panic("AssertIsEqual will never be satisfied on nonzero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Zero, Var: x, @@ -22,6 +29,13 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { // AssertIsDifferent constrains i1 and i2 to have different values. func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if v.IsZero() { + panic("AssertIsDifferent will never be satisfied on zero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.NonZero, Var: x, @@ -31,6 +45,12 @@ func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { // AssertIsBoolean adds an assertion that the variable is either 0 or 1. func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { x := builder.toVariableId(i1) + if b, ok := builder.constantValue(x); ok { + if !(b.IsZero() || builder.field.IsOne(b)) { + panic("assertIsBoolean failed: constant is not 0 or 1") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Bool, Var: x, diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e90f8f0b..3367ed99 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -35,6 +35,8 @@ type builder struct { nbExternalInput int maxVar int + varConstId []int + constValues []constraint.Element // defers (for gnark API) defers []func(frontend.API) error @@ -58,6 +60,8 @@ func (r *Root) newBuilder(nbExternalInput int) *builder { builder.tOne = builder.field.One() builder.maxVar = nbExternalInput + builder.varConstId = make([]int, nbExternalInput+1) + builder.constValues = make([]constraint.Element, 1) return &builder } @@ -106,11 +110,24 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { // ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { - return nil, false + coeff, ok := builder.constantValue(builder.toVariableId(v)) + if !ok { + return nil, false + } + return builder.field.ToBigInt(coeff), true +} + +func (builder *builder) constantValue(x int) (constraint.Element, bool) { + i := builder.varConstId[x] + if i == 0 { + return constraint.Element{}, false + } + return builder.constValues[i], true } func (builder *builder) addVarId() int { builder.maxVar += 1 + builder.varConstId = append(builder.varConstId, 0) return builder.maxVar } @@ -124,7 +141,10 @@ func (builder *builder) ceToId(x constraint.Element) int { ExtraId: 0, Const: x, }) - return builder.addVarId() + res := builder.addVarId() + builder.constValues = append(builder.constValues, x) + builder.varConstId[res] = len(builder.constValues) - 1 + return res } // toVariable will return (and allocate if neccesary) an Expression from given value @@ -147,6 +167,10 @@ func (builder *builder) toVariableId(input interface{}) int { } } +func (builder *builder) toVariable(input interface{}) frontend.Variable { + return newVariable(builder.toVariableId(input)) +} + // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions func (builder *builder) toVariableIds(in ...frontend.Variable) []int { r := make([]int, 0, len(in)) @@ -195,8 +219,7 @@ func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []f res := make([]frontend.Variable, nbOutputs) for i := 0; i < nbOutputs; i++ { - builder.maxVar += 1 - res[i] = newVariable(builder.maxVar) + res[i] = builder.addVar() } return res, nil } diff --git a/ecgo/examples/poseidon_m31/main.go b/ecgo/examples/poseidon_m31/main.go index c292e6bf..af9f3e44 100644 --- a/ecgo/examples/poseidon_m31/main.go +++ b/ecgo/examples/poseidon_m31/main.go @@ -4,11 +4,10 @@ import ( "fmt" "os" + poseidonM31 "github.com/PolyhedraZK/ExpanderCompilerCollection/circuit-std-go/poseidon-m31" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/poseidon" ecc_test "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" - "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" ) @@ -21,16 +20,14 @@ const NumRepeat = 120 type MockPoseidonM31Circuit struct { State [NumRepeat][16]frontend.Variable - Digest [NumRepeat]frontend.Variable `gnark:",public"` - Params *poseidon.PoseidonParams + Digest [NumRepeat]frontend.Variable } func (c *MockPoseidonM31Circuit) Define(api frontend.API) (err error) { // Define the circuit - engine := m31.Field{} for i := 0; i < NumRepeat; i++ { - digest := poseidon.PoseidonCircuit(api, engine, c.Params, c.State[i][:], true) - api.AssertIsEqual(digest, c.Digest[i]) + digest := poseidonM31.PoseidonM31x16Permutate(api, c.State[i][:]) + api.AssertIsEqual(digest[0], c.Digest[i]) } return @@ -38,42 +35,39 @@ func (c *MockPoseidonM31Circuit) Define(api frontend.API) (err error) { func M31CircuitBuild() { - param := poseidon.NewPoseidonParams() - - var states [NumRepeat][16]constraint.Element var stateVars [NumRepeat][16]frontend.Variable var outputVars [NumRepeat]frontend.Variable for i := 0; i < NumRepeat; i++ { - for j := 0; j < 16; j++ { - states[i][j] = constraint.Element{uint64(i)} - stateVars[i][j] = frontend.Variable(uint64(i)) + + for j := 0; j < 8; j++ { + stateVars[i][j] = frontend.Variable(0) + } + + for j := 8; j < 16; j++ { + stateVars[i][j] = frontend.Variable(114514) } - output := poseidon.PoseidonM31(param, states[i][:]) - outputVars[i] = frontend.Variable(output[0]) + + outputVars[i] = frontend.Variable(1021105124) } assignment := &MockPoseidonM31Circuit{ State: stateVars, Digest: outputVars, - Params: param, } // Ecc test circuit, err := ecgo.Compile(m31.ScalarField, &MockPoseidonM31Circuit{ State: stateVars, Digest: outputVars, - Params: param, }, frontend.WithCompressThreshold(32)) if err != nil { panic(err) } layered_circuit := circuit.GetLayeredCircuit() - // circuit.GetCircuitIr().Print() - err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644) - if err != nil { + if err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644); err != nil { panic(err) } inputSolver := circuit.GetInputSolver() @@ -81,8 +75,8 @@ func M31CircuitBuild() { if err != nil { panic(err) } - err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644) - if err != nil { + + if err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644); err != nil { panic(err) } if !ecc_test.CheckCircuit(layered_circuit, witness) { diff --git a/ecgo/poseidon/param.go b/ecgo/poseidon/param.go deleted file mode 100644 index 4e5938e0..00000000 --- a/ecgo/poseidon/param.go +++ /dev/null @@ -1,79 +0,0 @@ -package poseidon - -import "math/rand" - -type PoseidonParams struct { - // number of full rounds - NumFullRounds int - // number of half full rounds - NumHalfFullRounds int - // number of partial rounds - NumPartRounds int - // number of half full rounds - NumHalfPartialRounds int - // number of states - NumStates int - // mds matrix - MdsMatrix [][]uint32 - // external round constants - ExternalRoundConstant [][]uint32 - // internal round constants - InternalRoundConstant []uint32 -} - -// TODOs: the parameters are not secure. use a better way to generate the constants -func NewPoseidonParams() *PoseidonParams { - r := rand.New(rand.NewSource(42)) - - num_full_rounds := 8 - num_part_rounds := 14 - num_states := 16 - - external_round_constant := make([][]uint32, num_states) - for i := 0; i < num_states; i++ { - external_round_constant[i] = make([]uint32, num_full_rounds) - for j := 0; j < num_full_rounds; j++ { - external_round_constant[i][j] = randomM31(r) - } - } - - internal_round_constant := make([]uint32, num_part_rounds) - for i := 0; i < num_part_rounds; i++ { - internal_round_constant[i] = randomM31(r) - } - - // mds parameters adopted from Plonky3 - // https://github.com/Plonky3/Plonky3/blob/eeb4e37b20127c4daa871b2bad0df30a7c7380db/mersenne-31/src/mds.rs#L176 - mds := make([][]uint32, num_states) - mds[0] = make([]uint32, 16) - mds[0] = []uint32{1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3} - for i := 1; i < 16; i++ { - mds[i] = make([]uint32, 16) - // cyclic rotation of the first row - for j := 0; j < 16; j++ { - mds[i][j] = mds[0][(j+i)%16] - } - - } - - return &PoseidonParams{ - NumFullRounds: num_full_rounds, - NumHalfFullRounds: num_full_rounds / 2, - NumPartRounds: num_part_rounds, - NumHalfPartialRounds: num_part_rounds / 2, - NumStates: num_states, - MdsMatrix: mds, - ExternalRoundConstant: external_round_constant, - InternalRoundConstant: internal_round_constant, - } -} - -func randomM31(r *rand.Rand) uint32 { - t := r.Uint32() & 0x7FFFFFFF - - for t == 0x7fffffff { - t = rand.Uint32() & 0x7FFFFFFF - } - - return t -} diff --git a/ecgo/poseidon/poseidon.go b/ecgo/poseidon/poseidon.go deleted file mode 100644 index 8d9ab3ff..00000000 --- a/ecgo/poseidon/poseidon.go +++ /dev/null @@ -1,110 +0,0 @@ -// Poseidon hash function, written in the layered circuit. -package poseidon - -import ( - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/consensys/gnark/constraint" -) - -type PoseidonInternalState struct { - AfterHalfFullRound [16]constraint.Element - AfterHalfPartialRound [16]constraint.Element - AfterPartialRound [16]constraint.Element -} - -func sBox(engine m31.Field, f constraint.Element) constraint.Element { - x2 := engine.Mul(f, f) - x4 := engine.Mul(x2, x2) - return engine.Mul(x4, f) -} - -func PoseidonM31(param *PoseidonParams, input []constraint.Element) constraint.Element { - _, output := PoseidonM31WithInternalStates(param, input, false) - return output -} - -// Poseidon hash function over M31 field. -// For convenience, function also outputs an internal state when the hash function is half complete. -func PoseidonM31WithInternalStates(param *PoseidonParams, input []constraint.Element, withState bool) (PoseidonInternalState, constraint.Element) { - // todo: pad the input if it is too short - if len(input) != param.NumStates { - panic("input length does not match the number of states in the Poseidon parameters") - } - - state := input - engine := m31.Field{} - internalState := PoseidonInternalState{} - - // Applies the full rounds. - for i := 0; i < param.NumHalfFullRounds; i++ { - for j := 0; j < param.NumStates; j++ { - state[j] = engine.Add(state[j], engine.FromInterface(param.ExternalRoundConstant[j][i])) - } - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyExternalRoundMatrix(engine, state) - for j := 0; j < param.NumStates; j++ { - state[j] = sBox(engine, state[j]) - } - } - if withState { - copy(internalState.AfterHalfFullRound[:], state) - } - - // Applies the first half of partial rounds. - for i := 0; i < param.NumHalfPartialRounds; i++ { - state[0] = engine.Add(state[0], engine.FromInterface(param.InternalRoundConstant[i])) - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyInternalRoundMatrix(engine, state) - state[0] = sBox(engine, state[0]) - } - - if withState { - copy(internalState.AfterHalfPartialRound[:], state) - } - - // Applies the second half of partial rounds. - for i := 0; i < param.NumHalfPartialRounds; i++ { - state[0] = engine.Add(state[0], engine.FromInterface(param.InternalRoundConstant[i+param.NumHalfPartialRounds])) - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyInternalRoundMatrix(engine, state) - state[0] = sBox(engine, state[0]) - } - if withState { - copy(internalState.AfterPartialRound[:], state) - } - - // Applies the full rounds. - for i := 0; i < param.NumHalfFullRounds; i++ { - for j := 0; j < param.NumStates; j++ { - state[j] = engine.Add(state[j], engine.FromInterface(param.ExternalRoundConstant[j][i+param.NumHalfFullRounds])) - } - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyExternalRoundMatrix(engine, state) - for j := 0; j < param.NumStates; j++ { - state[j] = sBox(engine, state[j]) - } - } - - return internalState, state[0] -} - -// we use original poseidon mds method here -// it seems to be more efficient than poseidon2 for us as it requires less number of additions -func applyMdsMatrix(engine m31.Field, state []constraint.Element, mds [][]uint32) []constraint.Element { - tmp := make([]constraint.Element, len(state)) - for i := 0; i < len(state); i++ { - tmp[i] = engine.Mul(state[0], constraint.Element{uint64(mds[i][0])}) - for j := 1; j < len(state); j++ { - tmp[i] = engine.Add(tmp[i], engine.Mul(state[j], constraint.Element{uint64(mds[i][j])})) - } - } - return tmp -} diff --git a/ecgo/poseidon/poseidon_circuit.go b/ecgo/poseidon/poseidon_circuit.go deleted file mode 100644 index 9f558573..00000000 --- a/ecgo/poseidon/poseidon_circuit.go +++ /dev/null @@ -1,138 +0,0 @@ -package poseidon - -import ( - "log" - "math/big" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/customgates" - "github.com/consensys/gnark/frontend" -) - -type PoseidonInternalStateVar struct { - AfterHalfFullRound [16]frontend.Variable - AfterHalfPartialRound [16]frontend.Variable - AfterPartialRound [16]frontend.Variable -} - -// Suppose we have a x^4 gate, which has id 12345 in the prover -const GATE_5TH_POWER_TYPE = 12345 -const GATE_4TH_POWER_COST = 20 - -const GATE_MUL_TYPE = 12346 -const GATE_MUL_COST = 20 - -func Mul(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - a := big.NewInt(0) - a.Mul(inputs[0], big.NewInt(1)) - outputs[0] = a - return nil -} - -func init() { - customgates.Register(GATE_5TH_POWER_TYPE, Power5, GATE_4TH_POWER_COST) - customgates.Register(GATE_MUL_TYPE, Mul, GATE_MUL_COST) -} - -// Main function of proving poseidon in circuit. -// -// To obtain a more efficient layered circuit representation, we also feed the internal state of the hash to this function. -func PoseidonCircuit( - api frontend.API, - engine m31.Field, - param *PoseidonParams, - input []frontend.Variable, - useRandomness bool) frontend.Variable { - // todo: pad the input if it is too short - if len(input) != param.NumStates { - log.Println("input length", len(input), "does not match the number of states in the Poseidon parameters") - panic("") - } - - // ============================ - // Applies the full rounds. - // ============================ - state := input - - for i := 0; i < param.NumHalfFullRounds; i++ { - // add round constant - for j := 0; j < param.NumStates; j++ { - state[j] = api.Add(state[j], param.ExternalRoundConstant[j][i]) - } - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - for j := 0; j < param.NumStates; j++ { - state[j] = sBoxCircuit(api, state[j]) - } - } - - // ============================ - // Applies the first half of partial rounds. - // ============================ - - for i := 0; i < param.NumPartRounds; i++ { - // add round constant - state[0] = api.Add(state[0], param.InternalRoundConstant[i]) - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - state[0] = sBoxCircuit(api, state[0]) - for j := 1; j < param.NumStates; j++ { - state[j] = api.(ecgo.API).CustomGate(GATE_MUL_TYPE, state[j]) - } - } - - // ============================ - // Applies the full rounds. - // ============================ - - for i := 0; i < param.NumHalfFullRounds; i++ { - // add round constant - for j := 0; j < param.NumStates; j++ { - state[j] = api.Add(state[j], param.ExternalRoundConstant[j][i+param.NumHalfFullRounds]) - } - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - for j := 0; j < param.NumStates; j++ { - state[j] = sBoxCircuit(api, state[j]) - } - } - - return state[0] -} - -func accumulate(api frontend.API, a []frontend.Variable) frontend.Variable { - return api.Add(a[0], a[1], a[2:]...) -} - -func applyMdsMatrixCircuit(api frontend.API, x []frontend.Variable, mds [][]uint32) [16]frontend.Variable { - var res [16]frontend.Variable - for i := 0; i < 16; i++ { - var tmp [16]frontend.Variable - for j := 0; j < 16; j++ { - tmp[j] = api.Mul(x[j], mds[j][i]) - } - res[i] = accumulate(api, tmp[:]) - } - return res -} - -func Power5(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - a := big.NewInt(0) - a.Mul(inputs[0], inputs[0]) - a.Mul(a, a) - a.Mul(a, inputs[0]) - outputs[0] = a - return nil -} - -// S-Box: raise element to the power of 5 -func sBoxCircuit(api frontend.API, input frontend.Variable) frontend.Variable { - return api.(ecgo.API).CustomGate(GATE_5TH_POWER_TYPE, input) -} diff --git a/ecgo/poseidon/poseidon_circuit_test.go b/ecgo/poseidon/poseidon_circuit_test.go deleted file mode 100644 index 4068858e..00000000 --- a/ecgo/poseidon/poseidon_circuit_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package poseidon - -import ( - "testing" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/frontend" -) - -type MockPoseidonCircuit struct { - State [16]frontend.Variable `gnark:",public"` - Output frontend.Variable `gnark:",public"` -} - -func (c *MockPoseidonCircuit) Define(api frontend.API) (err error) { - param := NewPoseidonParams() - engine := m31.Field{} - t := PoseidonCircuit(api, engine, param, c.State[:], false) - api.AssertIsEqual(t, c.Output) - - return -} - -func TestPoseidonCircuit(t *testing.T) { - param := NewPoseidonParams() - - var states [16]constraint.Element - var stateVars [16]frontend.Variable - var outputVar frontend.Variable - - for j := 0; j < 16; j++ { - states[j] = constraint.Element{uint64(j)} - stateVars[j] = frontend.Variable(uint64(j)) - } - output := PoseidonM31(param, states[:]) - outputVar = frontend.Variable(output[0]) - - assignment := &MockPoseidonCircuit{ - State: stateVars, - Output: outputVar, - } - - // Gnark test disabled as it does not support randomness and custom gates - // err := test.IsSolved(&MockPoseidonCircuit{}, assignment, m31.ScalarField) - // if err != nil { - // panic(err) - // } - // fmt.Println("Gnark test passed") - - // Ecc test - circuit, err := ecgo.Compile(m31.ScalarField, &MockPoseidonCircuit{}, frontend.WithCompressThreshold(32)) - if err != nil { - panic(err) - } - - layered_circuit := circuit.GetLayeredCircuit() - // circuit.GetCircuitIr().Print() - - inputSolver := circuit.GetInputSolver() - witness, err := inputSolver.SolveInputAuto(assignment) - if err != nil { - panic(err) - } - - if !test.CheckCircuit(layered_circuit, witness) { - panic("verification failed") - } -} diff --git a/ecgo/poseidon/poseidon_test.go b/ecgo/poseidon/poseidon_test.go deleted file mode 100644 index afcd7a15..00000000 --- a/ecgo/poseidon/poseidon_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package poseidon - -import ( - "testing" - - "github.com/consensys/gnark/constraint" - "github.com/stretchr/testify/assert" -) - -func TestPoseidon(t *testing.T) { - param := NewPoseidonParams() - - state := make([]constraint.Element, param.NumStates) - PoseidonM31(param, state) - - state = make([]constraint.Element, param.NumStates+1) - assert.Panics(t, func() { PoseidonM31(param, state) }) -} diff --git a/ecgo/utils/gnarkexpr/expr.go b/ecgo/utils/gnarkexpr/expr.go index 115115d3..e54ec638 100644 --- a/ecgo/utils/gnarkexpr/expr.go +++ b/ecgo/utils/gnarkexpr/expr.go @@ -22,7 +22,13 @@ func init() { } } +// gnark uses uint32 +const MaxVariables = (1 << 31) - 100 + func NewVar(x int) Expr { + if x < 0 || x >= MaxVariables { + panic("variable id out of range") + } v := builder.InternalVariable(uint32(x)) t := reflect.ValueOf(v).Index(0).Interface().(Expr) if t.WireID() != x { diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index a955bc20..0e51f2e0 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -12,12 +12,20 @@ clap.workspace = true ethnum.workspace = true halo2curves.workspace = true tiny-keccak.workspace = true +mpi_config.workspace = true +gkr_field_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true +expander_transcript.workspace = true gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true +crosslayer_prototype.workspace = true + +[dev-dependencies] +rayon = "1.9" +sha2 = "0.10.8" [[bin]] name = "trivial_circuit" diff --git a/expander_compiler/ec_go_lib/Cargo.toml b/expander_compiler/ec_go_lib/Cargo.toml index 3315a2ee..842f375f 100644 --- a/expander_compiler/ec_go_lib/Cargo.toml +++ b/expander_compiler/ec_go_lib/Cargo.toml @@ -8,6 +8,8 @@ crate-type = ["dylib"] [dependencies] rand.workspace = true +gkr_field_config.workspace = true +mpi_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true gkr.workspace = true diff --git a/expander_compiler/ec_go_lib/src/compile.rs b/expander_compiler/ec_go_lib/src/compile.rs new file mode 100644 index 00000000..755e2d6b --- /dev/null +++ b/expander_compiler/ec_go_lib/src/compile.rs @@ -0,0 +1,116 @@ +use expander_compiler::circuit::layered::NormalInputType; +use libc::{c_ulong, malloc}; +use std::ptr; +use std::slice; + +use expander_compiler::{ + circuit::{config, ir}, + utils::serde::Serde, +}; + +use super::*; + +#[repr(C)] +pub struct CompileResult { + ir_witness_gen: ByteArray, + layered: ByteArray, + error: ByteArray, +} + +fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec), String> +where + C: config::Config, +{ + let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) + .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; + let (ir_witness_gen, layered) = + expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) + .map_err(|e| e.to_string())?; + let mut ir_wg_s: Vec = Vec::new(); + ir_witness_gen + .serialize_into(&mut ir_wg_s) + .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; + let mut layered_s: Vec = Vec::new(); + layered + .serialize_into(&mut layered_s) + .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; + Ok((ir_wg_s, layered_s)) +} + +fn compile_inner(ir_source: Vec, config_id: u64) -> Result<(Vec, Vec), String> { + match_config_id!(config_id, compile_inner_with_config, (ir_source)) +} + +fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResult { + match result { + Ok((ir_witness_gen, layered)) => { + let ir_wg_len = ir_witness_gen.len(); + let layered_len = layered.len(); + let ir_wg_ptr = if ir_wg_len > 0 { + unsafe { + let ptr = malloc(ir_wg_len) as *mut u8; + ptr.copy_from(ir_witness_gen.as_ptr(), ir_wg_len); + ptr + } + } else { + ptr::null_mut() + }; + let layered_ptr = if layered_len > 0 { + unsafe { + let ptr = malloc(layered_len) as *mut u8; + ptr.copy_from(layered.as_ptr(), layered_len); + ptr + } + } else { + ptr::null_mut() + }; + CompileResult { + ir_witness_gen: ByteArray { + data: ir_wg_ptr, + length: ir_wg_len as c_ulong, + }, + layered: ByteArray { + data: layered_ptr, + length: layered_len as c_ulong, + }, + error: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + } + } + Err(error) => { + let error_len = error.len(); + let error_ptr = if error_len > 0 { + unsafe { + let ptr = malloc(error_len) as *mut u8; + ptr.copy_from(error.as_ptr(), error_len); + ptr + } + } else { + ptr::null_mut() + }; + CompileResult { + ir_witness_gen: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + layered: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + error: ByteArray { + data: error_ptr, + length: error_len as c_ulong, + }, + } + } + } +} + +#[no_mangle] +pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { + let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; + let result = compile_inner(ir_source.to_vec(), config_id); + to_compile_result(result) +} diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 800b2cd2..d710da0c 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -1,278 +1,27 @@ -use arith::FieldSerde; -use expander_compiler::circuit::layered; -use libc::{c_uchar, c_ulong, malloc}; -use std::io::Cursor; -use std::ptr; -use std::slice; - -use expander_compiler::{ - circuit::{config, ir}, - utils::serde::Serde, -}; +use expander_compiler::circuit::config::Config; +use libc::{c_uchar, c_ulong}; const ABI_VERSION: c_ulong = 4; -#[repr(C)] -pub struct ByteArray { - data: *mut c_uchar, - length: c_ulong, -} - -#[repr(C)] -pub struct CompileResult { - ir_witness_gen: ByteArray, - layered: ByteArray, - error: ByteArray, -} - -fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec), String> -where - C: config::Config, -{ - let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) - .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; - let (ir_witness_gen, layered) = - expander_compiler::compile::compile(&ir_source).map_err(|e| e.to_string())?; - let mut ir_wg_s: Vec = Vec::new(); - ir_witness_gen - .serialize_into(&mut ir_wg_s) - .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; - let mut layered_s: Vec = Vec::new(); - layered - .serialize_into(&mut layered_s) - .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; - Ok((ir_wg_s, layered_s)) -} - -fn compile_inner(ir_source: Vec, config_id: u64) -> Result<(Vec, Vec), String> { - match config_id { - 1 => compile_inner_with_config::(ir_source), - 2 => compile_inner_with_config::(ir_source), - 3 => compile_inner_with_config::(ir_source), - _ => Err(format!("unknown config id: {}", config_id)), - } -} - -fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResult { - match result { - Ok((ir_witness_gen, layered)) => { - let ir_wg_len = ir_witness_gen.len(); - let layered_len = layered.len(); - let ir_wg_ptr = if ir_wg_len > 0 { - unsafe { - let ptr = malloc(ir_wg_len) as *mut u8; - ptr.copy_from(ir_witness_gen.as_ptr(), ir_wg_len); - ptr - } - } else { - ptr::null_mut() - }; - let layered_ptr = if layered_len > 0 { - unsafe { - let ptr = malloc(layered_len) as *mut u8; - ptr.copy_from(layered.as_ptr(), layered_len); - ptr - } - } else { - ptr::null_mut() - }; - CompileResult { - ir_witness_gen: ByteArray { - data: ir_wg_ptr, - length: ir_wg_len as c_ulong, - }, - layered: ByteArray { - data: layered_ptr, - length: layered_len as c_ulong, - }, - error: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - } - } - Err(error) => { - let error_len = error.len(); - let error_ptr = if error_len > 0 { - unsafe { - let ptr = malloc(error_len) as *mut u8; - ptr.copy_from(error.as_ptr(), error_len); - ptr - } - } else { - ptr::null_mut() - }; - CompileResult { - ir_witness_gen: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - layered: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - error: ByteArray { - data: error_ptr, - length: error_len as c_ulong, - }, - } +#[macro_export] +macro_rules! match_config_id { + ($config_id:ident, $inner:ident, $args:tt) => { + match $config_id { + x if x == config::M31Config::CONFIG_ID as u64 => $inner:: $args, + x if x == config::BN254Config::CONFIG_ID as u64 => $inner:: $args, + x if x == config::GF2Config::CONFIG_ID as u64 => $inner:: $args, + _ => Err(format!("unknown config id: {}", $config_id)), } } } -#[no_mangle] -pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { - let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; - let result = compile_inner(ir_source.to_vec(), config_id); - to_compile_result(result) -} - -fn dump_proof_and_claimed_v( - proof: &expander_transcript::Proof, - claimed_v: &F, -) -> Vec { - let mut bytes = Vec::new(); - - proof.serialize_into(&mut bytes).unwrap(); // TODO: error propagation - claimed_v.serialize_into(&mut bytes).unwrap(); // TODO: error propagation - - bytes -} - -fn load_proof_and_claimed_v( - bytes: &[u8], -) -> Result<(expander_transcript::Proof, F), ()> { - let mut cursor = Cursor::new(bytes); - - let proof = expander_transcript::Proof::deserialize_from(&mut cursor).map_err(|_| ())?; - let claimed_v = F::deserialize_from(&mut cursor).map_err(|_| ())?; +pub mod compile; +pub mod proving; - Ok((proof, claimed_v)) -} - -fn prove_circuit_file_inner( - circuit_filename: &str, - witness: &[u8], -) -> Vec -where - C::SimdCircuitField: arith::SimdField, -{ - let config = expander_config::Config::::new( - expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), - ); - let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); - let (simd_input, simd_public_input) = witness.to_simd::(); - circuit.layers[0].input_vals = simd_input; - circuit.public_input = simd_public_input; - circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&circuit); - let (claimed_v, proof) = prover.prove(&mut circuit); - dump_proof_and_claimed_v(&proof, &claimed_v) -} - -fn verify_circuit_file_inner( - circuit_filename: &str, - witness: &[u8], - proof_and_claimed_v: &[u8], -) -> u8 -where - C::SimdCircuitField: arith::SimdField, -{ - let config = expander_config::Config::::new( - expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), - ); - let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); - let (simd_input, simd_public_input) = witness.to_simd::(); - circuit.layers[0].input_vals = simd_input; - circuit.public_input = simd_public_input.clone(); - let (proof, claimed_v) = match load_proof_and_claimed_v(proof_and_claimed_v) { - Ok((proof, claimed_v)) => (proof, claimed_v), - Err(_) => { - return 0; - } - }; - let verifier = gkr::Verifier::new(&config); - verifier.verify(&mut circuit, &simd_public_input, &claimed_v, &proof) as u8 -} - -#[no_mangle] -pub extern "C" fn prove_circuit_file( - circuit_filename: ByteArray, - witness: ByteArray, - config_id: c_ulong, -) -> ByteArray { - let circuit_filename = unsafe { - let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); - std::str::from_utf8(slice).unwrap() - }; - let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; - let proof = match config_id { - 1 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - 2 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - 3 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - _ => panic!("unknown config id: {}", config_id), - }; - let proof_len = proof.len(); - let proof_ptr = if proof_len > 0 { - unsafe { - let ptr = malloc(proof_len) as *mut u8; - ptr.copy_from(proof.as_ptr(), proof_len); - ptr - } - } else { - ptr::null_mut() - }; - ByteArray { - data: proof_ptr, - length: proof_len as c_ulong, - } -} - -#[no_mangle] -pub extern "C" fn verify_circuit_file( - circuit_filename: ByteArray, - witness: ByteArray, - proof: ByteArray, - config_id: c_ulong, -) -> c_uchar { - let circuit_filename = unsafe { - let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); - std::str::from_utf8(slice).unwrap() - }; - let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; - let proof = unsafe { slice::from_raw_parts(proof.data, proof.length as usize) }; - match config_id { - 1 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - 2 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - 3 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - _ => panic!("unknown config id: {}", config_id), - } +#[repr(C)] +pub struct ByteArray { + data: *mut c_uchar, + length: c_ulong, } #[no_mangle] diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs new file mode 100644 index 00000000..205e05bb --- /dev/null +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -0,0 +1,107 @@ +use expander_compiler::circuit::layered; +use libc::{c_uchar, c_ulong, malloc}; +use std::ptr; +use std::slice; + +use expander_compiler::{circuit::config, utils::serde::Serde}; + +use super::*; + +fn prove_circuit_file_inner( + circuit_filename: &str, + witness: &[u8], +) -> Result, String> { + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + let mut circuit = + expander_circuit::Circuit::::load_circuit(circuit_filename); + let witness = + layered::witness::Witness::::deserialize_from(witness).map_err(|e| e.to_string())?; + let (simd_input, simd_public_input) = witness.to_simd::(); + circuit.layers[0].input_vals = simd_input; + circuit.public_input = simd_public_input; + circuit.evaluate(); + let (claimed_v, proof) = gkr::executor::prove(&mut circuit, &config); + gkr::executor::dump_proof_and_claimed_v(&proof, &claimed_v).map_err(|e| e.to_string()) +} + +fn verify_circuit_file_inner( + circuit_filename: &str, + witness: &[u8], + proof_and_claimed_v: &[u8], +) -> Result { + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + let mut circuit = + expander_circuit::Circuit::::load_circuit(circuit_filename); + let witness = + layered::witness::Witness::::deserialize_from(witness).map_err(|e| e.to_string())?; + let (simd_input, simd_public_input) = witness.to_simd::(); + circuit.layers[0].input_vals = simd_input; + circuit.public_input = simd_public_input.clone(); + let (proof, claimed_v) = match gkr::executor::load_proof_and_claimed_v(proof_and_claimed_v) { + Ok((proof, claimed_v)) => (proof, claimed_v), + Err(_) => { + return Ok(0); + } + }; + Ok(gkr::executor::verify(&mut circuit, &config, &proof, &claimed_v) as u8) +} + +#[no_mangle] +pub extern "C" fn prove_circuit_file( + circuit_filename: ByteArray, + witness: ByteArray, + config_id: c_ulong, +) -> ByteArray { + let circuit_filename = unsafe { + let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); + std::str::from_utf8(slice).unwrap() + }; + let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; + let proof = match_config_id!( + config_id, + prove_circuit_file_inner, + (circuit_filename, witness) + ) + .unwrap(); // TODO: handle error + let proof_len = proof.len(); + let proof_ptr = if proof_len > 0 { + unsafe { + let ptr = malloc(proof_len) as *mut u8; + ptr.copy_from(proof.as_ptr(), proof_len); + ptr + } + } else { + ptr::null_mut() + }; + ByteArray { + data: proof_ptr, + length: proof_len as c_ulong, + } +} + +#[no_mangle] +pub extern "C" fn verify_circuit_file( + circuit_filename: ByteArray, + witness: ByteArray, + proof: ByteArray, + config_id: c_ulong, +) -> c_uchar { + let circuit_filename = unsafe { + let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); + std::str::from_utf8(slice).unwrap() + }; + let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; + let proof = unsafe { slice::from_raw_parts(proof.data, proof.length as usize) }; + match_config_id!( + config_id, + verify_circuit_file_inner, + (circuit_filename, witness, proof) + ) + .unwrap() // TODO: handle error +} diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index 6f79df1f..87a4d274 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -936,7 +936,7 @@ mod tests { } _ => panic!(), } - let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let inputs: Vec = (1..=100000).map(CField::from).collect(); let (out, ok) = root.eval_unsafe(inputs.clone()); let (out2, ok2) = root_processed.eval_unsafe(inputs); assert_eq!(out, out2); @@ -959,7 +959,7 @@ mod tests { assert_eq!(root.validate(), Ok(())); let root_processed = super::process(&root).unwrap(); assert_eq!(root_processed.validate(), Ok(())); - let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let inputs: Vec = (1..=100000).map(CField::from).collect(); let (out, ok) = root.eval_unsafe(inputs.clone()); let (out2, ok2) = root_processed.eval_unsafe(inputs); assert_eq!(out, out2); diff --git a/expander_compiler/src/circuit/config.rs b/expander_compiler/src/circuit/config.rs index f90fb6a5..0319fd44 100644 --- a/expander_compiler/src/circuit/config.rs +++ b/expander_compiler/src/circuit/config.rs @@ -5,6 +5,13 @@ use crate::field::Field; pub trait Config: Default + Clone + Ord + Debug + Hash + Copy + 'static { type CircuitField: Field; + type DefaultSimdField: arith::SimdField; + type DefaultGKRFieldConfig: gkr_field_config::GKRFieldConfig< + CircuitField = Self::CircuitField, + SimdCircuitField = Self::DefaultSimdField, + >; + type DefaultGKRConfig: expander_config::GKRConfig; + const CONFIG_ID: usize; const COST_INPUT: usize = 1000; @@ -22,6 +29,10 @@ pub struct M31Config {} impl Config for M31Config { type CircuitField = crate::field::M31; + type DefaultSimdField = mersenne31::M31x16; + type DefaultGKRFieldConfig = gkr_field_config::M31ExtConfig; + type DefaultGKRConfig = gkr::executor::M31ExtConfigSha2; + const CONFIG_ID: usize = 1; } @@ -31,6 +42,10 @@ pub struct BN254Config {} impl Config for BN254Config { type CircuitField = crate::field::BN254; + type DefaultSimdField = crate::field::BN254; + type DefaultGKRFieldConfig = gkr_field_config::BN254Config; + type DefaultGKRConfig = gkr::executor::BN254ConfigMIMC5; + const CONFIG_ID: usize = 2; } @@ -40,6 +55,10 @@ pub struct GF2Config {} impl Config for GF2Config { type CircuitField = crate::field::GF2; + type DefaultSimdField = gf2::GF2x8; + type DefaultGKRFieldConfig = gkr_field_config::GF2ExtConfig; + type DefaultGKRConfig = gkr::executor::GF2ExtConfigSha2; + const CONFIG_ID: usize = 3; // temporary fix for Keccak_GF2 diff --git a/expander_compiler/src/circuit/ir/common/rand_gen.rs b/expander_compiler/src/circuit/ir/common/rand_gen.rs index ded0094d..1bc64bdc 100644 --- a/expander_compiler/src/circuit/ir/common/rand_gen.rs +++ b/expander_compiler/src/circuit/ir/common/rand_gen.rs @@ -16,9 +16,7 @@ pub trait RandomConstraintType { } impl RandomConstraintType for RawConstraintType { - fn random(_r: impl RngCore) -> Self { - () - } + fn random(_r: impl RngCore) -> Self {} } pub struct RandomRange { diff --git a/expander_compiler/src/circuit/ir/dest/mod.rs b/expander_compiler/src/circuit/ir/dest/mod.rs index 07415c79..f6cdc750 100644 --- a/expander_compiler/src/circuit/ir/dest/mod.rs +++ b/expander_compiler/src/circuit/ir/dest/mod.rs @@ -16,6 +16,7 @@ use super::{ pub mod tests; pub mod display; +pub mod mul_fanout_limit; #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs new file mode 100644 index 00000000..61824d79 --- /dev/null +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -0,0 +1,478 @@ +use super::*; + +// This module contains the implementation of the optimization that reduces the fanout of the input variables in multiplication gates. +// There are two ways to reduce the fanout of a variable: +// 1. Copy the whole expression to a new variable. This will copy all gates, and may increase the number of gates by a lot. +// 2. Create a relay expression of the variable. This may increase the layer of the circuit by 1. + +// These are the limits for the first method. +const MAX_COPIES_OF_VARIABLES: usize = 4; +const MAX_COPIES_OF_GATES: usize = 64; + +fn compute_max_copy_cnt(num_gates: usize) -> usize { + if num_gates == 0 { + return 0; + } + MAX_COPIES_OF_VARIABLES.min(MAX_COPIES_OF_GATES / num_gates) +} + +struct NewIdQueue { + queue: Vec<(usize, usize)>, + next: usize, + default_id: usize, +} + +impl NewIdQueue { + fn new(default_id: usize) -> Self { + Self { + queue: Vec::new(), + next: 0, + default_id, + } + } + + fn push(&mut self, id: usize, num: usize) { + self.queue.push((id, num)); + } + + fn get(&mut self) -> usize { + while self.next < self.queue.len() { + let (id, num) = self.queue[self.next]; + if num > 0 { + self.queue[self.next].1 -= 1; + return id; + } + self.next += 1; + } + self.default_id + } +} + +impl CircuitRelaxed { + fn solve_mul_fanout_limit(&self, limit: usize) -> CircuitRelaxed { + let mut max_copy_cnt = vec![0; self.num_inputs + 1]; + let mut mul_ref_cnt = vec![0; self.num_inputs + 1]; + let mut internal_var_insn_id = vec![None; self.num_inputs + 1]; + + for (i, insn) in self.instructions.iter().enumerate() { + match insn { + Instruction::ConstantLike { .. } => { + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(1)); + internal_var_insn_id.push(None); + } + Instruction::SubCircuitCall { num_outputs, .. } => { + for _ in 0..*num_outputs { + mul_ref_cnt.push(0); + max_copy_cnt.push(0); + internal_var_insn_id.push(None); + } + } + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(expr.len())); + internal_var_insn_id.push(Some(i)) + } + } + } + + let mut add_copy_cnt = vec![0; max_copy_cnt.len()]; + let mut relay_cnt = vec![0; max_copy_cnt.len()]; + let mut any_new = false; + + for i in (1..max_copy_cnt.len()).rev() { + let mc = max_copy_cnt[i].max(1); + if mul_ref_cnt[i] <= mc * limit { + add_copy_cnt[i] = ((mul_ref_cnt[i] + limit - 1) / limit).max(1) - 1; + any_new = true; + if let Some(j) = internal_var_insn_id[i] { + if let Instruction::InternalVariable { expr } = &self.instructions[j] { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += add_copy_cnt[i]; + mul_ref_cnt[y] += add_copy_cnt[i]; + } + } + } else { + unreachable!(); + } + } + } else { + // mul_ref_cnt[i] + relay_cnt[i] <= limit * (1 + relay_cnt[i]) + relay_cnt[i] = (mul_ref_cnt[i] - 2) / (limit - 1); + any_new = true; + } + } + + if !any_new { + return self.clone(); + } + + let mut new_id = vec![]; + let mut new_insns: Vec> = Vec::new(); + let mut new_var_max = self.num_inputs; + let mut last_solved_id = 0; + + for i in 0..=self.num_inputs { + new_id.push(NewIdQueue::new(i)); + } + + for insn in self.instructions.iter() { + while last_solved_id + 1 < new_id.len() { + last_solved_id += 1; + let x = last_solved_id; + if add_copy_cnt[x] == 0 && relay_cnt[x] == 0 { + continue; + } + let y = new_id[x].default_id; + new_id[x].push(y, limit); + for _ in 0..add_copy_cnt[x] { + let insn = new_insns.last().unwrap().clone(); + new_insns.push(insn); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + for _ in 0..relay_cnt[x] { + let y = new_id[x].get(); + new_insns.push(Instruction::InternalVariable { + expr: Expression::new_linear(C::CircuitField::one(), y), + }); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + } + match insn { + Instruction::ConstantLike { value } => { + new_insns.push(Instruction::ConstantLike { + value: value.clone(), + }); + new_var_max += 1; + new_id.push(NewIdQueue::new(new_var_max)); + } + Instruction::SubCircuitCall { + sub_circuit_id, + inputs, + num_outputs, + } => { + new_insns.push(Instruction::SubCircuitCall { + sub_circuit_id: *sub_circuit_id, + inputs: inputs.iter().map(|x| new_id[*x].default_id).collect(), + num_outputs: *num_outputs, + }); + for _ in 0..*num_outputs { + new_var_max += 1; + let x = new_id.len(); + new_id.push(NewIdQueue::new(new_var_max)); + assert_eq!(add_copy_cnt[x], 0); + } + } + Instruction::InternalVariable { expr } => { + let x = new_id.len(); + if add_copy_cnt[x] > 0 { + assert_eq!(relay_cnt[x], 0); + } + for _ in 0..=add_copy_cnt[x] { + let mut new_terms = vec![]; + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + new_terms.push(Term { + vars: VarSpec::Quad(new_id[x].get(), new_id[y].get()), + coef: term.coef, + }); + } else { + new_terms.push(Term { + vars: term.vars.replace_vars(|x| new_id[x].default_id), + coef: term.coef, + }); + } + } + new_insns.push(Instruction::InternalVariable { + expr: Expression::from_terms(new_terms), + }); + new_var_max += 1; + } + new_id.push(NewIdQueue::new(new_var_max)); + if add_copy_cnt[x] > 0 { + for i in 0..=add_copy_cnt[x] { + new_id[x].push(new_var_max - add_copy_cnt[x] + i, limit); + } + last_solved_id = x; + } + } + } + } + + CircuitRelaxed { + instructions: new_insns, + num_inputs: self.num_inputs, + outputs: self.outputs.iter().map(|x| new_id[*x].default_id).collect(), + constraints: self + .constraints + .iter() + .map(|x| new_id[*x].default_id) + .collect(), + } + } +} + +impl RootCircuitRelaxed { + pub fn solve_mul_fanout_limit(&self, limit: usize) -> RootCircuitRelaxed { + if limit <= 1 { + panic!("limit must be greater than 1"); + } + + let mut circuits = HashMap::new(); + for (id, circuit) in self.circuits.iter() { + circuits.insert(*id, circuit.solve_mul_fanout_limit(limit)); + } + RootCircuitRelaxed { + circuits, + num_public_inputs: self.num_public_inputs, + expected_num_output_zeroes: self.expected_num_output_zeroes, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::config::{Config, M31Config as C}; + use crate::circuit::layered::{InputUsize, NormalInputType}; + use crate::field::FieldArith; + use rand::{RngCore, SeedableRng}; + + type CField = ::CircuitField; + + fn verify_mul_fanout(rc: &RootCircuitRelaxed, limit: usize) { + for circuit in rc.circuits.values() { + let mut mul_ref_cnt = vec![0; circuit.num_inputs + 1]; + for insn in circuit.instructions.iter() { + match insn { + Instruction::ConstantLike { .. } => {} + Instruction::SubCircuitCall { .. } => {} + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + } + } + for _ in 0..insn.num_outputs() { + mul_ref_cnt.push(0); + } + } + for x in mul_ref_cnt.iter().skip(1) { + assert!(*x <= limit); + } + } + } + + fn do_test(root: RootCircuitRelaxed, limits: Vec) { + for lim in limits.iter() { + let new_root = root.solve_mul_fanout_limit(*lim); + assert_eq!(new_root.validate(), Ok(())); + assert_eq!(new_root.input_size(), root.input_size()); + verify_mul_fanout(&new_root, *lim); + let inputs: Vec = (0..root.input_size()) + .map(|_| CField::random_unsafe(&mut rand::thread_rng())) + .collect(); + let (out1, cond1) = root.eval_unsafe(inputs.clone()); + let (out2, cond2) = new_root.eval_unsafe(inputs); + assert_eq!(out1, out2); + assert_eq!(cond1, cond2); + } + } + + #[test] + fn fanout_test1() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 2, + }; + for i in 3..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::one(), 1, 2), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test2() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test3() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::SubCircuitCall { + sub_circuit_id: 1, + inputs: vec![1], + num_outputs: 1, + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + root.circuits.insert( + 1, + CircuitRelaxed { + instructions: vec![Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }], + constraints: vec![], + outputs: vec![2], + num_inputs: 1, + }, + ); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test_random() { + let mut rnd = rand::rngs::StdRng::seed_from_u64(3); + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 100, + }; + let mut q = vec![]; + for i in 1..=100 { + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let n = 10003; + + for i in 101..=n { + let mut terms = vec![]; + let mut c = q.len() / 2; + if i != n { + c = c.min(5); + } + for _ in 0..c { + let x = q.swap_remove(rnd.next_u64() as usize % q.len()); + let y = q.swap_remove(rnd.next_u64() as usize % q.len()); + terms.push(Term { + vars: VarSpec::Quad(x, y), + coef: CField::one(), + }); + } + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::from_terms(terms), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn full_fanout_test_and_dump() { + use crate::circuit::ir::common::rand_gen::{RandomCircuitConfig, RandomRange}; + use crate::utils::serde::Serde; + + let config = RandomCircuitConfig { + seed: 2, + num_circuits: RandomRange { min: 20, max: 20 }, + num_inputs: RandomRange { min: 1, max: 3 }, + num_instructions: RandomRange { min: 30, max: 50 }, + num_constraints: RandomRange { min: 0, max: 5 }, + num_outputs: RandomRange { min: 1, max: 3 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.05, + }; + let root = crate::circuit::ir::source::RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let (_, circuit) = crate::compile::compile_with_options::<_, NormalInputType>( + &root, + crate::compile::CompileOptions::default().with_mul_fanout_limit(256), + ) + .unwrap(); + assert_eq!(circuit.validate(), Ok(())); + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs.get(0)]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0].offset] += 1; + ref_num[m.inputs[1].offset] += 1; + } + for x in ref_num.iter() { + assert!(*x <= 256); + } + } + + let mut buf = Vec::new(); + circuit.serialize_into(&mut buf).unwrap(); + } +} diff --git a/expander_compiler/src/circuit/ir/expr.rs b/expander_compiler/src/circuit/ir/expr.rs index e9743f88..d6724091 100644 --- a/expander_compiler/src/circuit/ir/expr.rs +++ b/expander_compiler/src/circuit/ir/expr.rs @@ -79,6 +79,18 @@ impl VarSpec { (_, VarSpec::RandomLinear(_)) => panic!("unexpected situation: RandomLinear"), } } + pub fn replace_vars usize>(&self, f: F) -> Self { + match self { + VarSpec::Const => VarSpec::Const, + VarSpec::Linear(x) => VarSpec::Linear(f(*x)), + VarSpec::Quad(x, y) => VarSpec::Quad(f(*x), f(*y)), + VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { + gate_type: *gate_type, + inputs: inputs.iter().cloned().map(&f).collect(), + }, + VarSpec::RandomLinear(x) => VarSpec::RandomLinear(f(*x)), + } + } } impl Ord for Term { @@ -310,16 +322,7 @@ impl Expression { .iter() .map(|term| Term { coef: term.coef, - vars: match &term.vars { - VarSpec::Const => VarSpec::Const, - VarSpec::Linear(index) => VarSpec::Linear(f(*index)), - VarSpec::Quad(index1, index2) => VarSpec::Quad(f(*index1), f(*index2)), - VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { - gate_type: *gate_type, - inputs: inputs.iter().cloned().map(&f).collect(), - }, - VarSpec::RandomLinear(index) => VarSpec::RandomLinear(f(*index)), - }, + vars: term.vars.replace_vars(&f), }) .collect(); Expression { terms } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index b5b169d7..408a5e5f 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use crate::field::FieldArith; +use crate::hints::registry::HintCaller; use crate::utils::error::Error; use crate::{ circuit::{ @@ -201,6 +202,38 @@ impl common::Instruction for Instruction { } } +impl Instruction { + fn eval_safe( + &self, + values: &[C::CircuitField], + public_inputs: &[C::CircuitField], + hint_caller: &mut impl HintCaller, + ) -> EvalResult { + if let Instruction::ConstantLike(coef) = self { + return match coef { + Coef::Constant(c) => EvalResult::Value(*c), + Coef::PublicInput(i) => EvalResult::Value(public_inputs[*i]), + Coef::Random => EvalResult::Error(Error::UserError( + "random coef occured in witness solver".to_string(), + )), + }; + } + if let Instruction::Hint { + hint_id, + inputs, + num_outputs, + } = self + { + let inputs: Vec = inputs.iter().map(|i| values[*i]).collect(); + return match hints::safe_impl(hint_caller, *hint_id, &inputs, *num_outputs) { + Ok(outputs) => EvalResult::Values(outputs), + Err(e) => EvalResult::Error(e), + }; + } + self.eval_unsafe(values) + } +} + pub type Circuit = common::Circuit>; pub type RootCircuit = common::RootCircuit>; @@ -443,41 +476,27 @@ impl RootCircuit { self.circuits.insert(0, c0); } - pub fn eval_with_public_inputs( + pub fn eval_safe( &self, inputs: Vec, public_inputs: &[C::CircuitField], + hint_caller: &mut impl HintCaller, ) -> Result, Error> { assert_eq!(inputs.len(), self.input_size()); - self.eval_sub_with_public_inputs(&self.circuits[&0], inputs, public_inputs) + self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_caller) } - fn eval_sub_with_public_inputs( + fn eval_sub_safe( &self, circuit: &Circuit, inputs: Vec, public_inputs: &[C::CircuitField], + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut values = vec![C::CircuitField::zero(); 1]; values.extend(inputs); for insn in circuit.instructions.iter() { - if let Instruction::ConstantLike(coef) = insn { - match coef { - Coef::Constant(c) => { - values.push(*c); - } - Coef::PublicInput(i) => { - values.push(public_inputs[*i]); - } - Coef::Random => { - return Err(Error::UserError( - "random coef occured in witness solver".to_string(), - )); - } - } - continue; - } - match insn.eval_unsafe(&values) { + match insn.eval_safe(&values, public_inputs, hint_caller) { EvalResult::Value(v) => { values.push(v); } @@ -485,10 +504,11 @@ impl RootCircuit { values.append(&mut vs); } EvalResult::SubCircuitCall(sub_circuit_id, inputs) => { - let res = self.eval_sub_with_public_inputs( + let res = self.eval_sub_safe( &self.circuits[&sub_circuit_id], inputs.iter().map(|&i| values[i]).collect(), public_inputs, + hint_caller, )?; values.extend(res); } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 970307d4..77473faa 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -11,10 +11,11 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_caller: &mut impl HintCaller, ) -> Result<(Vec, usize), Error> { assert_eq!(vars.len(), self.circuit.input_size()); assert_eq!(public_vars.len(), self.circuit.num_public_inputs); - let mut a = self.circuit.eval_with_public_inputs(vars, &public_vars)?; + let mut a = self.circuit.eval_safe(vars, &public_vars, hint_caller)?; let res_len = a.len(); a.extend(public_vars); Ok((a, res_len)) @@ -24,8 +25,10 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { - let (values, num_inputs_per_witness) = self.solve_witness_inner(vars, public_vars)?; + let (values, num_inputs_per_witness) = + self.solve_witness_inner(vars, public_vars, hint_caller)?; Ok(Witness { num_witnesses: 1, num_inputs_per_witness, @@ -40,12 +43,13 @@ impl WitnessSolver { &self, num_witnesses: usize, f: F, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; for i in 0..num_witnesses { let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b)?; + let (a, num) = self.solve_witness_inner(a, b, hint_caller)?; values.extend(a); num_inputs_per_witness = num; } diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index d90af870..64ed3b7b 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -140,6 +140,7 @@ impl Circuit { } impl RootCircuit { + // this function must be used with remove_unreachable pub fn detect_chains(&mut self) { for (_, circuit) in self.circuits.iter_mut() { circuit.detect_chains(); diff --git a/expander_compiler/src/circuit/ir/source/tests.rs b/expander_compiler/src/circuit/ir/source/tests.rs index e4f7e08b..7b789b30 100644 --- a/expander_compiler/src/circuit/ir/source/tests.rs +++ b/expander_compiler/src/circuit/ir/source/tests.rs @@ -1,7 +1,7 @@ use rand::{Rng, RngCore}; use super::{ - ConstraintType, + Circuit, ConstraintType, Instruction::{self, ConstantLike, LinComb, Mul}, RootCircuit, }; @@ -190,3 +190,75 @@ fn opt_remove_unreachable_2() { } } } + +fn test_detect_chains_inner(is_mul: bool, seq_typ: usize) { + let n = 1000000; + let mut root = RootCircuit::::default(); + let mut insns = vec![]; + let mut lst = 1; + let get_insn = if is_mul { + |x, y| Instruction::::Mul(vec![x, y]) + } else { + |x, y| { + Instruction::LinComb(expr::LinComb { + terms: vec![ + expr::LinCombTerm { + coef: CField::one(), + var: x, + }, + expr::LinCombTerm { + coef: CField::one(), + var: y, + }, + ], + constant: CField::zero(), + }) + } + }; + if seq_typ == 1 { + lst = n; + for i in (1..n).rev() { + insns.push(get_insn(lst, i)); + lst = n * 2 - i; + } + } else if seq_typ == 2 { + for i in 2..=n { + insns.push(get_insn(lst, i)); + lst = n - 1 + i; + } + } else { + let mut q: Vec = (1..=n).collect(); + let mut i = 0; + lst = n; + while i + 1 < q.len() { + lst += 1; + insns.push(get_insn(q[i], q[i + 1])); + q.push(lst); + i += 2; + } + } + root.circuits.insert( + 0, + Circuit:: { + num_inputs: n, + instructions: insns, + constraints: vec![], + outputs: vec![lst], + }, + ); + assert_eq!(root.validate(), Ok(())); + root.detect_chains(); + let (root, _) = root.remove_unreachable(); + println!("{:?}", root); + assert_eq!(root.validate(), Ok(())); +} + +#[test] +fn test_detect_chains() { + test_detect_chains_inner(false, 1); + test_detect_chains_inner(false, 2); + test_detect_chains_inner(false, 3); + test_detect_chains_inner(true, 1); + test_detect_chains_inner(true, 2); + test_detect_chains_inner(true, 3); +} diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index c5718704..844c34e2 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -1,8 +1,8 @@ use super::*; -impl Circuit { +impl Circuit { pub fn export_to_expander< - DestConfig: expander_config::GKRConfig, + DestConfig: gkr_field_config::GKRFieldConfig, >( &self, ) -> expander_circuit::RecursiveCircuit { @@ -10,7 +10,7 @@ impl Circuit { .segments .iter() .map(|seg| expander_circuit::Segment { - i_var_num: seg.num_inputs.trailing_zeros() as usize, + i_var_num: seg.num_inputs.get(0).trailing_zeros() as usize, o_var_num: seg.num_outputs.trailing_zeros() as usize, gate_muls: seg .gate_muls @@ -33,7 +33,7 @@ impl Circuit { .map(|gate| { let (c, r) = gate.coef.export_to_expander(); expander_circuit::GateUni { - i_ids: [gate.inputs[0]], + i_ids: [gate.inputs[0].offset()], o_id: gate.output, coef: c, coef_type: r, @@ -50,7 +50,7 @@ impl Circuit { seg.1 .iter() .map(|alloc| expander_circuit::Allocation { - i_offset: alloc.input_offset, + i_offset: alloc.input_offset.get(0), o_offset: alloc.output_offset, }) .collect(), @@ -68,3 +68,71 @@ impl Circuit { } } } + +impl Circuit { + pub fn export_to_expander< + DestConfig: gkr_field_config::GKRFieldConfig, + >( + &self, + ) -> crosslayer_prototype::CrossLayerRecursiveCircuit { + let mut segments = Vec::new(); + for segment in self.segments.iter() { + let mut gate_adds = Vec::new(); + let mut gate_relays = Vec::new(); + for gate in segment.gate_adds.iter() { + if gate.inputs[0].layer() == 0 { + gate_adds.push(gate.export_to_crosslayer_simple()); + } else { + let (c, r) = gate.coef.export_to_expander(); + assert_eq!(r, expander_circuit::CoefType::Constant); + gate_relays.push(crosslayer_prototype::CrossLayerRelay { + i_id: gate.inputs[0].offset(), + o_id: gate.output, + i_layer: gate.inputs[0].layer(), + coef: c, + }); + } + } + assert_eq!(segment.gate_customs.len(), 0); + segments.push(crosslayer_prototype::CrossLayerSegment { + input_size: segment.num_inputs.to_vec(), + output_size: segment.num_outputs, + child_segs: segment + .child_segs + .iter() + .map(|seg| { + ( + seg.0, + seg.1 + .iter() + .map(|alloc| crosslayer_prototype::Allocation { + i_offset: alloc.input_offset.to_vec(), + o_offset: alloc.output_offset, + }) + .collect(), + ) + }) + .collect(), + gate_muls: segment + .gate_muls + .iter() + .map(|gate| gate.export_to_crosslayer_simple()) + .collect(), + gate_csts: segment + .gate_consts + .iter() + .map(|gate| gate.export_to_crosslayer_simple()) + .collect(), + gate_adds, + gate_relay: gate_relays, + }); + } + crosslayer_prototype::CrossLayerRecursiveCircuit { + num_public_inputs: self.num_public_inputs, + num_outputs: self.num_actual_outputs, + expected_num_output_zeros: self.expected_num_output_zeroes, + layers: self.layer_ids.clone(), + segments, + } + } +} diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index 83a72fb0..72472950 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -2,7 +2,11 @@ use std::{fmt, hash::Hash}; use arith::FieldForECC; -use crate::{field::FieldArith, hints, utils::error::Error}; +use crate::{ + field::FieldArith, + hints, + utils::{error::Error, serde::Serde}, +}; use super::config::Config; @@ -109,78 +113,263 @@ impl Coef { } } +#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInput { + // the actual layer of the input is (output_layer-1-layer) + pub layer: usize, + pub offset: usize, +} + +#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInput { + pub offset: usize, +} + +pub trait Input: + std::fmt::Debug + + std::fmt::Display + + Clone + + Copy + + Default + + Hash + + PartialEq + + Eq + + PartialOrd + + Ord + + Serde +{ + fn layer(&self) -> usize; + fn offset(&self) -> usize; + fn set_offset(&mut self, offset: usize); + fn new(layer: usize, offset: usize) -> Self; +} + +impl Input for CrossLayerInput { + fn layer(&self) -> usize { + self.layer + } + fn offset(&self) -> usize { + self.offset + } + fn set_offset(&mut self, offset: usize) { + self.offset = offset; + } + fn new(layer: usize, offset: usize) -> Self { + CrossLayerInput { layer, offset } + } +} + +impl Input for NormalInput { + fn layer(&self) -> usize { + 0 + } + fn offset(&self) -> usize { + self.offset + } + fn set_offset(&mut self, offset: usize) { + self.offset = offset; + } + fn new(layer: usize, offset: usize) -> Self { + if layer != 0 { + panic!("new called on non-zero layer"); + } + NormalInput { offset } + } +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInputUsize { + v: Vec, +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInputUsize { + v: usize, +} + +pub trait InputUsize: + std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord + Serde +{ + type Iter<'a>: Iterator + where + Self: 'a; + fn len(&self) -> usize; + fn iter(&self) -> Self::Iter<'_>; + fn get(&self, i: usize) -> usize { + self.iter().nth(i).unwrap() + } + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn from_vec(v: Vec) -> Self; + fn to_vec(&self) -> Vec { + self.iter().collect() + } +} + +impl InputUsize for CrossLayerInputUsize { + type Iter<'a> = std::iter::Copied>; + fn len(&self) -> usize { + self.v.len() + } + fn iter(&self) -> Self::Iter<'_> { + self.v.iter().copied() + } + fn from_vec(v: Vec) -> Self { + CrossLayerInputUsize { v } + } +} + +impl InputUsize for NormalInputUsize { + type Iter<'a> = std::iter::Once; + fn len(&self) -> usize { + 1 + } + fn iter(&self) -> Self::Iter<'_> { + std::iter::once(self.v) + } + fn from_vec(v: Vec) -> Self { + if v.len() != 1 { + panic!("from_vec called on non-singleton vec"); + } + NormalInputUsize { v: v[0] } + } +} + +pub trait InputType: + std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord +{ + type Input: Input; + type InputUsize: InputUsize; + const CROSS_LAYER_RELAY: bool; +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInputType; + +impl InputType for CrossLayerInputType { + type Input = CrossLayerInput; + type InputUsize = CrossLayerInputUsize; + const CROSS_LAYER_RELAY: bool = true; +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInputType; + +impl InputType for NormalInputType { + type Input = NormalInput; + type InputUsize = NormalInputUsize; + const CROSS_LAYER_RELAY: bool = false; +} + #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct Gate { - pub inputs: [usize; INPUT_NUM], +pub struct Gate { + pub inputs: [I::Input; INPUT_NUM], pub output: usize, pub coef: Coef, } -impl Gate { +impl Gate { pub fn export_to_expander< - DestConfig: expander_config::GKRConfig, + DestConfig: gkr_field_config::GKRFieldConfig, >( &self, ) -> expander_circuit::Gate { let (c, r) = self.coef.export_to_expander(); + let mut i_ids: [usize; INPUT_NUM] = [0; INPUT_NUM]; + for (x, y) in self.inputs.iter().zip(i_ids.iter_mut()) { + *y = x.offset(); + } expander_circuit::Gate { - i_ids: self.inputs, + i_ids, o_id: self.output, coef: c, coef_type: r, - gate_type: 2 - INPUT_NUM, // TODO: check this + gate_type: 2 - INPUT_NUM, } } } -pub type GateMul = Gate; -pub type GateAdd = Gate; -pub type GateConst = Gate; +impl Gate { + pub fn export_to_crosslayer_simple< + DestConfig: gkr_field_config::GKRFieldConfig, + >( + &self, + ) -> crosslayer_prototype::SimpleGate { + let (c, r) = self.coef.export_to_expander(); + let mut i_ids: [usize; INPUT_NUM] = [0; INPUT_NUM]; + for (x, y) in self.inputs.iter().zip(i_ids.iter_mut()) { + assert_eq!(x.layer(), 0); + *y = x.offset(); + } + crosslayer_prototype::SimpleGate { + i_ids, + o_id: self.output, + coef: c, + coef_type: match r { + expander_circuit::CoefType::Constant => crosslayer_prototype::CoefType::Constant, + expander_circuit::CoefType::Random => crosslayer_prototype::CoefType::Random, + expander_circuit::CoefType::PublicInput(x) => { + crosslayer_prototype::CoefType::PublicInput(x) + } + }, + } + } +} + +pub type GateMul = Gate; +pub type GateAdd = Gate; +pub type GateConst = Gate; #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct GateCustom { +pub struct GateCustom { pub gate_type: usize, - pub inputs: Vec, + pub inputs: Vec, pub output: usize, pub coef: Coef, } #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Allocation { - pub input_offset: usize, +pub struct Allocation { + pub input_offset: I::InputUsize, pub output_offset: usize, } -pub type ChildSpec = (usize, Vec); +pub type ChildSpec = (usize, Vec>); #[derive(Default, Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Segment { - pub num_inputs: usize, +pub struct Segment { + pub num_inputs: I::InputUsize, pub num_outputs: usize, - pub child_segs: Vec, - pub gate_muls: Vec>, - pub gate_adds: Vec>, - pub gate_consts: Vec>, - pub gate_customs: Vec>, + pub child_segs: Vec>, + pub gate_muls: Vec>, + pub gate_adds: Vec>, + pub gate_consts: Vec>, + pub gate_customs: Vec>, } #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Circuit { +pub struct Circuit { pub num_public_inputs: usize, pub num_actual_outputs: usize, pub expected_num_output_zeroes: usize, - pub segments: Vec>, + pub segments: Vec>, pub layer_ids: Vec, } -impl Circuit { +impl Circuit { pub fn validate(&self) -> Result<(), Error> { for (i, seg) in self.segments.iter().enumerate() { - if seg.num_inputs == 0 || (seg.num_inputs & (seg.num_inputs - 1)) != 0 { - return Err(Error::InternalError(format!( - "segment {} inputlen {} not power of 2", - i, seg.num_inputs - ))); + for (j, x) in seg.num_inputs.iter().enumerate() { + if x == 0 || (x & (x - 1)) != 0 { + return Err(Error::InternalError(format!( + "segment {} input {} len {} not power of 2", + i, j, x + ))); + } + } + if seg.num_inputs.len() == 0 { + return Err(Error::InternalError(format!("segment {} inputlen 0", i))); } if seg.num_outputs == 0 || (seg.num_outputs & (seg.num_outputs - 1)) != 0 { return Err(Error::InternalError(format!( @@ -189,20 +378,53 @@ impl Circuit { ))); } for m in seg.gate_muls.iter() { - if m.inputs[0] >= seg.num_inputs - || m.inputs[1] >= seg.num_inputs - || m.output >= seg.num_outputs - { + if m.inputs[0].layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 0 layer out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[1].layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 1 layer out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[0].offset() >= seg.num_inputs.get(m.inputs[0].layer()) { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 0 out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[1].offset() >= seg.num_inputs.get(m.inputs[1].layer()) { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 1 out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.output >= seg.num_outputs { return Err(Error::InternalError(format!( - "segment {} mul gate ({}, {}, {}) out of range", + "segment {} mul gate ({:?}, {:?}, {}) out of range", i, m.inputs[0], m.inputs[1], m.output ))); } } for a in seg.gate_adds.iter() { - if a.inputs[0] >= seg.num_inputs || a.output >= seg.num_outputs { + if a.inputs[0].layer() >= self.layer_ids.len() { return Err(Error::InternalError(format!( - "segment {} add gate ({}, {}) out of range", + "segment {} add gate ({:?}, {}) input layer out of range", + i, a.inputs[0], a.output + ))); + } + if a.inputs[0].offset() >= seg.num_inputs.get(a.inputs[0].layer()) { + return Err(Error::InternalError(format!( + "segment {} add gate ({:?}, {}) input out of range", + i, a.inputs[0], a.output + ))); + } + if a.output >= seg.num_outputs { + return Err(Error::InternalError(format!( + "segment {} add gate ({:?}, {}) out of range", i, a.inputs[0], a.output ))); } @@ -216,11 +438,17 @@ impl Circuit { } } for cu in seg.gate_customs.iter() { - for &input in cu.inputs.iter() { - if input >= seg.num_inputs { + for input in cu.inputs.iter() { + if input.layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} custom gate {} input layer out of range", + i, cu.output + ))); + } + if input.offset() >= seg.num_inputs.get(input.layer()) { return Err(Error::InternalError(format!( "segment {} custom gate {} input out of range", - i, input + i, cu.output ))); } } @@ -239,18 +467,43 @@ impl Circuit { ))); } let subc = &self.segments[*sub_id]; + if subc.num_inputs.len() > seg.num_inputs.len() { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input length {} larger than {}", + i, + sub_id, + subc.num_inputs.len(), + seg.num_inputs.len() + ))); + } for a in allocs.iter() { - if a.input_offset % subc.num_inputs != 0 { + if a.input_offset.len() != subc.num_inputs.len() { return Err(Error::InternalError(format!( - "segment {} subcircuit {} input offset {} not aligned to {}", - i, sub_id, a.input_offset, subc.num_inputs + "segment {} subcircuit {} input offset {:?} length not equal to {}", + i, + sub_id, + a.input_offset, + subc.num_inputs.len() ))); } - if a.input_offset + subc.num_inputs > seg.num_inputs { - return Err(Error::InternalError(format!( - "segment {} subcircuit {} input offset {} out of range", - i, sub_id, a.input_offset - ))); + for ((x, y), z) in a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .zip(seg.num_inputs.iter()) + { + if x % y != 0 { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input offset {} not aligned to {}", + i, sub_id, x, y + ))); + } + if x + y > z { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input offset {} out of range", + i, sub_id, x + ))); + } } if a.output_offset % subc.num_outputs != 0 { return Err(Error::InternalError(format!( @@ -275,65 +528,101 @@ impl Circuit { if self.layer_ids.is_empty() { return Err(Error::InternalError("empty layer".to_string())); } - for i in 1..self.layer_ids.len() { - let cur = &self.segments[self.layer_ids[i]]; - let prev = &self.segments[self.layer_ids[i - 1]]; - if cur.num_inputs != prev.num_outputs { + let mut layer_sizes = Vec::with_capacity(self.layer_ids.len() + 1); + layer_sizes.push(self.segments[self.layer_ids[0]].num_inputs.get(0)); + for l in self.layer_ids.iter() { + layer_sizes.push(self.segments[*l].num_outputs); + } + for (i, l) in self.layer_ids.iter().enumerate() { + let cur = &self.segments[*l]; + if cur.num_inputs.len() > i + 1 { return Err(Error::InternalError(format!( - "segment {} inputlen {} not equal to segment {} outputlen {}", - self.layer_ids[i], - cur.num_inputs, - self.layer_ids[i - 1], - prev.num_outputs + "layer {} input length {} larger than {}", + i, + cur.num_inputs.len(), + i + 1 ))); } - } - let (input_mask, output_mask) = self.compute_masks(); - for i in 1..self.layer_ids.len() { - for j in 0..self.segments[self.layer_ids[i]].num_inputs { - if input_mask[self.layer_ids[i]][j] && !output_mask[self.layer_ids[i - 1]][j] { + for (j, x) in cur.num_inputs.iter().enumerate() { + if x != layer_sizes[i - j] { return Err(Error::InternalError(format!( - "circuit {} input {} not initialized by circuit {} output", - self.layer_ids[i], + "layer {} input {} length {} not equal to {}", + i, j, - self.layer_ids[i - 1] + x, + layer_sizes[i - j] ))); } } } + let (input_mask, output_mask) = self.compute_masks(); + for i in 1..self.layer_ids.len() { + for (l, len) in self.segments[self.layer_ids[i]] + .num_inputs + .iter() + .enumerate() + { + if i == l { + // if this is also the global input, it's always initialized + continue; + } + for j in 0..len { + if input_mask[self.layer_ids[i]][l][j] + && !output_mask[self.layer_ids[i - 1 - l]][j] + { + return Err(Error::InternalError(format!( + "circuit {} (layer {}) input {} not initialized by circuit {} (layer {}) output", + self.layer_ids[i], + i, + j, + self.layer_ids[i - 1 - l], + i - 1 - l + ))); + } + } + } + } Ok(()) } - fn compute_masks(&self) -> (Vec>, Vec>) { - let mut input_mask: Vec> = Vec::with_capacity(self.segments.len()); + fn compute_masks(&self) -> (Vec>>, Vec>) { + let mut input_mask: Vec>> = Vec::with_capacity(self.segments.len()); let mut output_mask: Vec> = Vec::with_capacity(self.segments.len()); for seg in self.segments.iter() { - let mut input_mask_seg = vec![false; seg.num_inputs]; + let mut input_mask_seg: Vec> = + seg.num_inputs.iter().map(|x| vec![false; x]).collect(); let mut output_mask_seg = vec![false; seg.num_outputs]; for m in seg.gate_muls.iter() { - input_mask_seg[m.inputs[0]] = true; - input_mask_seg[m.inputs[1]] = true; + input_mask_seg[m.inputs[0].layer()][m.inputs[0].offset()] = true; + input_mask_seg[m.inputs[1].layer()][m.inputs[1].offset()] = true; output_mask_seg[m.output] = true; } for a in seg.gate_adds.iter() { - input_mask_seg[a.inputs[0]] = true; + input_mask_seg[a.inputs[0].layer()][a.inputs[0].offset()] = true; output_mask_seg[a.output] = true; } for cs in seg.gate_consts.iter() { output_mask_seg[cs.output] = true; } for cu in seg.gate_customs.iter() { - for &input in cu.inputs.iter() { - input_mask_seg[input] = true; + for input in cu.inputs.iter() { + input_mask_seg[input.layer()][input.offset()] = true; } output_mask_seg[cu.output] = true; } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { - for j in 0..subc.num_inputs { - input_mask_seg[a.input_offset + j] = - input_mask_seg[a.input_offset + j] || input_mask[*sub_id][j]; + for (l, (off, len)) in a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + { + for i in 0..len { + input_mask_seg[l][off + i] = + input_mask_seg[l][off + i] || input_mask[*sub_id][l][i]; + } } for j in 0..subc.num_outputs { output_mask_seg[a.output_offset + j] = @@ -348,19 +637,24 @@ impl Circuit { } pub fn input_size(&self) -> usize { - self.segments[self.layer_ids[0]].num_inputs + self.segments[self.layer_ids[0]].num_inputs.get(0) } pub fn eval_unsafe(&self, inputs: Vec) -> (Vec, bool) { if inputs.len() != self.input_size() { panic!("input length mismatch"); } - let mut cur = inputs; - for &id in self.layer_ids.iter() { - let mut next = vec![C::CircuitField::zero(); self.segments[id].num_outputs]; - self.apply_segment_unsafe(&self.segments[id], &cur, &mut next); - cur = next; + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![C::CircuitField::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[C::CircuitField]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } + self.apply_segment_unsafe(&self.segments[*id], &inputs, &mut next); + cur.push(next); } + let cur = cur.last().unwrap(); let mut constraints_satisfied = true; for out in cur.iter().take(self.expected_num_output_zeroes) { if !out.is_zero() { @@ -376,35 +670,45 @@ impl Circuit { fn apply_segment_unsafe( &self, - seg: &Segment, - cur: &[C::CircuitField], + seg: &Segment, + cur: &[&[C::CircuitField]], nxt: &mut [C::CircuitField], ) { for m in seg.gate_muls.iter() { - nxt[m.output] += cur[m.inputs[0]] * cur[m.inputs[1]] * m.coef.get_value_unsafe(); + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] + * m.coef.get_value_unsafe(); } for a in seg.gate_adds.iter() { - nxt[a.output] += cur[a.inputs[0]] * a.coef.get_value_unsafe(); + nxt[a.output] += + cur[a.inputs[0].layer()][a.inputs[0].offset()] * a.coef.get_value_unsafe(); } for cs in seg.gate_consts.iter() { nxt[cs.output] += cs.coef.get_value_unsafe(); } for cu in seg.gate_customs.iter() { let mut inputs = Vec::with_capacity(cu.inputs.len()); - for &input in cu.inputs.iter() { - inputs.push(cur[input]); + for input in cu.inputs.iter() { + inputs.push(cur[input.layer()][input.offset()]); } let outputs = hints::stub_impl(cu.gate_type, &inputs, 1); - for (i, &output) in outputs.iter().enumerate() { - nxt[cu.output + i] += output * cu.coef.get_value_unsafe(); + for (i, output) in outputs.iter().enumerate() { + nxt[cu.output + i] += *output * cu.coef.get_value_unsafe(); } } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); self.apply_segment_unsafe( subc, - &cur[a.input_offset..a.input_offset + subc.num_inputs], + &inputs, &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], ); } @@ -419,17 +723,22 @@ impl Circuit { if inputs.len() != self.input_size() { panic!("input length mismatch"); } - let mut cur = inputs; - for &id in self.layer_ids.iter() { - let mut next = vec![C::CircuitField::zero(); self.segments[id].num_outputs]; + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![C::CircuitField::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[C::CircuitField]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } self.apply_segment_with_public_inputs( - &self.segments[id], - &cur, + &self.segments[*id], + &inputs, &mut next, public_inputs, ); - cur = next; + cur.push(next); } + let cur = cur.last().unwrap(); let mut constraints_satisfied = true; for out in cur.iter().take(self.expected_num_output_zeroes) { if !out.is_zero() { @@ -445,38 +754,46 @@ impl Circuit { fn apply_segment_with_public_inputs( &self, - seg: &Segment, - cur: &[C::CircuitField], + seg: &Segment, + cur: &[&[C::CircuitField]], nxt: &mut [C::CircuitField], public_inputs: &[C::CircuitField], ) { for m in seg.gate_muls.iter() { - nxt[m.output] += cur[m.inputs[0]] - * cur[m.inputs[1]] + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] * m.coef.get_value_with_public_inputs(public_inputs); } for a in seg.gate_adds.iter() { - nxt[a.output] += cur[a.inputs[0]] * a.coef.get_value_with_public_inputs(public_inputs); + nxt[a.output] += cur[a.inputs[0].layer()][a.inputs[0].offset()] + * a.coef.get_value_with_public_inputs(public_inputs); } for cs in seg.gate_consts.iter() { nxt[cs.output] += cs.coef.get_value_with_public_inputs(public_inputs); } for cu in seg.gate_customs.iter() { let mut inputs = Vec::with_capacity(cu.inputs.len()); - for &input in cu.inputs.iter() { - inputs.push(cur[input]); + for input in cu.inputs.iter() { + inputs.push(cur[input.layer()][input.offset()]); } let outputs = hints::stub_impl(cu.gate_type, &inputs, 1); - for (i, &output) in outputs.iter().enumerate() { - nxt[cu.output + i] += output * cu.coef.get_value_unsafe(); + for (i, output) in outputs.iter().enumerate() { + nxt[cu.output + i] += *output * cu.coef.get_value_with_public_inputs(public_inputs); } } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); self.apply_segment_with_public_inputs( subc, - &cur[a.input_offset..a.input_offset + subc.num_inputs], + &inputs, &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], public_inputs, ); @@ -505,15 +822,27 @@ impl fmt::Display for Coef { } } -impl fmt::Display for Segment { +impl fmt::Display for CrossLayerInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "(layer={}, offset={})", self.layer, self.offset) + } +} + +impl fmt::Display for NormalInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.offset) + } +} + +impl fmt::Display for Segment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "input={} output={}", self.num_inputs, self.num_outputs)?; + writeln!(f, "input={:?} output={}", self.num_inputs, self.num_outputs)?; for (sub_id, allocs) in self.child_segs.iter() { writeln!(f, "apply circuit {} at:", sub_id)?; for a in allocs.iter() { writeln!( f, - " input_offset={} output_offset={}", + " input_offset={:?} output_offset={}", a.input_offset, a.output_offset )?; } @@ -545,7 +874,7 @@ impl fmt::Display for Segment { } } -impl fmt::Display for Circuit { +impl fmt::Display for Circuit { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, seg) in self.segments.iter().enumerate() { write!(f, "Circuit {}: {}", i, seg)?; diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 9bc232b6..188445ec 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -9,23 +9,14 @@ use crate::utils::{misc::next_power_of_two, union_find::UnionFind}; use super::*; -impl PartialOrd for Gate { +impl PartialOrd for Gate { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for Gate { +impl Ord for Gate { fn cmp(&self, other: &Self) -> Ordering { - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; for i in 0..INPUT_NUM { match self.inputs[i].cmp(&other.inputs[i]) { Ordering::Less => { @@ -37,17 +28,26 @@ impl Ord for Gate { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } -impl PartialOrd for GateCustom { +impl PartialOrd for GateCustom { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for GateCustom { +impl Ord for GateCustom { fn cmp(&self, other: &Self) -> Ordering { match self.gate_type.cmp(&other.gate_type) { Ordering::Less => { @@ -58,15 +58,6 @@ impl Ord for GateCustom { } Ordering::Equal => {} }; - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; match self.inputs.len().cmp(&other.inputs.len()) { Ordering::Less => { return Ordering::Less; @@ -87,18 +78,27 @@ impl Ord for GateCustom { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } -trait GateOpt: PartialEq + Ord + Clone { +trait GateOpt: PartialEq + Ord + Clone { fn coef_add(&mut self, coef: Coef); fn can_merge_with(&self, other: &Self) -> bool; fn get_coef(&self) -> Coef; - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self; + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self; } -impl GateOpt for Gate { +impl GateOpt for Gate { fn coef_add(&mut self, coef: Coef) { self.coef = self.coef.add_constant(coef.get_constant().unwrap()); } @@ -111,10 +111,10 @@ impl GateOpt for Gate { fn get_coef(&self) -> Coef { self.coef.clone() } - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self { + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self { let mut inputs = self.inputs; for input in inputs.iter_mut() { - *input += in_offset; + input.set_offset(input.offset() + in_offset.get(input.layer())); } let output = self.output + out_offset; let coef = self.coef.clone(); @@ -126,7 +126,7 @@ impl GateOpt for Gate { } } -impl GateOpt for GateCustom { +impl GateOpt for GateCustom { fn coef_add(&mut self, coef: Coef) { self.coef = self.coef.add_constant(coef.get_constant().unwrap()); } @@ -140,10 +140,10 @@ impl GateOpt for GateCustom { fn get_coef(&self) -> Coef { self.coef.clone() } - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self { + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self { let mut inputs = self.inputs.clone(); for input in inputs.iter_mut() { - *input += in_offset; + input.set_offset(input.offset() + in_offset.get(input.layer())); } let output = self.output + out_offset; let coef = self.coef.clone(); @@ -156,7 +156,7 @@ impl GateOpt for GateCustom { } } -fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { +fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { gates.sort(); let mut lst = 0; for i in 1..gates.len() { @@ -188,14 +188,14 @@ fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum UniGate { - Mul(GateMul), - Add(GateAdd), - Const(GateConst), - Custom(GateCustom), +enum UniGate { + Mul(GateMul), + Add(GateAdd), + Const(GateConst), + Custom(GateCustom), } -impl Segment { +impl Segment { fn dedup_gates(&mut self) { let mut occured_outputs = vec![false; self.num_outputs]; for gate in self.gate_muls.iter_mut() { @@ -239,7 +239,7 @@ impl Segment { self.gate_consts.sort(); } - fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { + fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { let tot_gates = self.num_all_gates(); let mut ids: HashSet = HashSet::new(); while ids.len() < num_gates && ids.len() < tot_gates { @@ -251,25 +251,25 @@ impl Segment { let tot_mul = self.gate_muls.len(); let tot_add = self.gate_adds.len(); let tot_const = self.gate_consts.len(); - for id in ids.iter() { - if *id < tot_mul { - gates.insert(UniGate::Mul(self.gate_muls[*id].clone())); - } else if *id < tot_mul + tot_add { - gates.insert(UniGate::Add(self.gate_adds[*id - tot_mul].clone())); - } else if *id < tot_mul + tot_add + tot_const { + for &id in ids.iter() { + if id < tot_mul { + gates.insert(UniGate::Mul(self.gate_muls[id].clone())); + } else if id < tot_mul + tot_add { + gates.insert(UniGate::Add(self.gate_adds[id - tot_mul].clone())); + } else if id < tot_mul + tot_add + tot_const { gates.insert(UniGate::Const( - self.gate_consts[*id - tot_mul - tot_add].clone(), + self.gate_consts[id - tot_mul - tot_add].clone(), )); } else { gates.insert(UniGate::Custom( - self.gate_customs[*id - tot_mul - tot_add - tot_const].clone(), + self.gate_customs[id - tot_mul - tot_add - tot_const].clone(), )); } } gates } - fn all_gates(&self) -> HashSet> { + fn all_gates(&self) -> HashSet> { let mut gates = HashSet::new(); for gate in self.gate_muls.iter() { gates.insert(UniGate::Mul(gate.clone())); @@ -293,7 +293,7 @@ impl Segment { + self.gate_customs.len() } - fn remove_gates(&mut self, gates: &HashSet>) { + fn remove_gates(&mut self, gates: &HashSet>) { let mut new_gates = Vec::new(); for gate in self.gate_muls.iter() { if !gates.contains(&UniGate::Mul(gate.clone())) { @@ -324,7 +324,7 @@ impl Segment { self.gate_customs = new_gates; } - fn from_uni_gates(gates: &HashSet>) -> Self { + fn from_uni_gates(gates: &HashSet>) -> Self { let mut gate_muls = Vec::new(); let mut gate_adds = Vec::new(); let mut gate_consts = Vec::new(); @@ -341,17 +341,23 @@ impl Segment { gate_adds.sort(); gate_consts.sort(); gate_customs.sort(); - let mut max_input = 0; + let mut max_input = Vec::new(); let mut max_output = 0; for gate in gate_muls.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } for gate in gate_adds.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } @@ -360,12 +366,19 @@ impl Segment { } for gate in gate_customs.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } + if max_input.is_empty() { + max_input.push(0); + } + let num_inputs_vec = max_input.iter().map(|x| next_power_of_two(x + 1)).collect(); Segment { - num_inputs: next_power_of_two(max_input + 1), + num_inputs: I::InputUsize::from_vec(num_inputs_vec), num_outputs: next_power_of_two(max_output + 1), gate_muls, gate_adds, @@ -376,17 +389,17 @@ impl Segment { } } -impl Circuit { +impl Circuit { pub fn dedup_gates(&mut self) { for segment in self.segments.iter_mut() { segment.dedup_gates(); } } - fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( + fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( &self, segment_id: usize, - prev_segments: &[Segment], + prev_segments: &[Segment], should_expand: F, get_gates: G, ) -> Vec { @@ -397,7 +410,7 @@ impl Circuit { let sub_segment = &prev_segments[*sub_segment_id]; let sub_gates = get_gates(sub_segment).clone(); for allocation in allocations.iter() { - let in_offset = allocation.input_offset; + let in_offset = &allocation.input_offset; let out_offset = allocation.output_offset; for gate in sub_gates.iter() { gates.push(gate.add_offset(in_offset, out_offset)); @@ -411,9 +424,9 @@ impl Circuit { fn expand_segment bool>( &self, segment_id: usize, - prev_segments: &[Segment], + prev_segments: &[Segment], should_expand: F, - ) -> Segment { + ) -> Segment { let segment = &self.segments[segment_id]; let gate_muls = self.expand_gates(segment_id, prev_segments, &should_expand, |s| &s.gate_muls); @@ -443,8 +456,14 @@ impl Circuit { } for sub_allocation in sub_allocations.iter() { for allocation in allocations.iter() { + let input_offset_vec = sub_allocation + .input_offset + .iter() + .zip(allocation.input_offset.iter()) + .map(|(x, y)| x + y) + .collect(); let new_allocation = Allocation { - input_offset: sub_allocation.input_offset + allocation.input_offset, + input_offset: I::InputUsize::from_vec(input_offset_vec), output_offset: sub_allocation.output_offset + allocation.output_offset, }; @@ -462,7 +481,7 @@ impl Circuit { } let child_segs = child_segs_map.into_iter().collect(); Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls, gate_adds, @@ -537,7 +556,7 @@ impl Circuit { new_child_segs.push((new_id[sub_segment.0], sub_segment.1.clone())); } let mut seg = Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls: segment.gate_muls.clone(), gate_adds: segment.gate_adds.clone(), @@ -565,12 +584,12 @@ impl Circuit { const COMMON_THRESHOLD_PERCENT: usize = 5; const COMMON_THRESHOLD_VALUE: usize = 10; let mut rng = rand::rngs::StdRng::seed_from_u64(123); //for deterministic - let sampled_gates: Vec>> = self + let sampled_gates: Vec>> = self .segments .iter() .map(|segment| segment.sample_gates(SAMPLE_PER_SEGMENT, &mut rng)) .collect(); - let all_gates: Vec>> = self + let all_gates: Vec>> = self .segments .iter() .map(|segment| segment.all_gates()) @@ -617,7 +636,7 @@ impl Circuit { if cnt < COMMON_THRESHOLD_VALUE { continue; } - let merged_gates: HashSet> = group_gates[x] + let merged_gates: HashSet> = group_gates[x] .intersection(&group_gates[y]) .cloned() .collect(); @@ -629,7 +648,7 @@ impl Circuit { size[uf.find(i)] += 1; } let mut rm_id: Vec> = vec![None; self.segments.len()]; - let mut new_segments: Vec> = Vec::new(); + let mut new_segments: Vec> = Vec::new(); let mut new_id = vec![!0; self.segments.len()]; for i in 0..self.segments.len() { if i == uf.find(i) && size[i] > 1 && group_gates[i].len() >= COMMON_THRESHOLD_VALUE { @@ -644,7 +663,7 @@ impl Circuit { new_child_segs.push((new_id[sub_segment.0], sub_segment.1.clone())); } let mut seg = Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls: segment.gate_muls.clone(), gate_adds: segment.gate_adds.clone(), @@ -655,10 +674,11 @@ impl Circuit { let parent_id = uf.find(segment_id); if let Some(common_id) = rm_id[parent_id] { seg.remove_gates(&group_gates[parent_id]); + let common_seg = &new_segments[common_id]; seg.child_segs.push(( common_id, vec![Allocation { - input_offset: 0, + input_offset: I::InputUsize::from_vec(vec![0; common_seg.num_inputs.len()]), output_offset: 0, }], )); @@ -691,10 +711,14 @@ mod tests { utils::error::Error, }; + use super::{CrossLayerInputType, InputType, NormalInputType}; + type CField = ::CircuitField; - fn get_random_layered_circuit(rcc: &RandomCircuitConfig) -> Option> { - let root = ir::dest::RootCircuitRelaxed::::random(&rcc); + fn get_random_layered_circuit( + rcc: &RandomCircuitConfig, + ) -> Option> { + let root = ir::dest::RootCircuitRelaxed::::random(rcc); let mut root = root.export_constraints(); root.reassign_duplicate_sub_circuit_outputs(); let root = root.remove_unreachable().0; @@ -716,8 +740,7 @@ mod tests { Some(lc) } - #[test] - fn dedup_gates_random() { + fn dedup_gates_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -730,7 +753,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 400000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -753,7 +776,12 @@ mod tests { } #[test] - fn expand_small_segments_random() { + fn dedup_gates_random() { + dedup_gates_random_::(); + dedup_gates_random_::(); + } + + fn expand_small_segments_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 100 }, @@ -766,7 +794,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 500000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -788,7 +816,12 @@ mod tests { } #[test] - fn find_common_parts_random() { + fn expand_small_segments_random() { + expand_small_segments_random_::(); + expand_small_segments_random_::(); + } + + fn find_common_parts_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 100 }, @@ -801,7 +834,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 600000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -821,4 +854,10 @@ mod tests { } } } + + #[test] + fn find_common_parts_random() { + find_common_parts_random_::(); + find_common_parts_random_::(); + } } diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index 99776ce6..c32a793e 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -41,7 +41,53 @@ impl Serde for Coef { } } -impl Serde for Gate { +impl Serde for CrossLayerInput { + fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { + self.layer.serialize_into(&mut writer)?; + self.offset.serialize_into(&mut writer)?; + Ok(()) + } + fn deserialize_from(mut reader: R) -> Result { + let layer = usize::deserialize_from(&mut reader)?; + let offset = usize::deserialize_from(&mut reader)?; + Ok(CrossLayerInput { layer, offset }) + } +} + +impl Serde for NormalInput { + fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { + self.offset.serialize_into(&mut writer)?; + Ok(()) + } + fn deserialize_from(mut reader: R) -> Result { + let offset = usize::deserialize_from(&mut reader)?; + Ok(NormalInput { offset }) + } +} + +impl Serde for CrossLayerInputUsize { + fn serialize_into(&self, writer: W) -> Result<(), IoError> { + self.v.serialize_into(writer) + } + fn deserialize_from(reader: R) -> Result { + Ok(CrossLayerInputUsize { + v: Vec::::deserialize_from(reader)?, + }) + } +} + +impl Serde for NormalInputUsize { + fn serialize_into(&self, writer: W) -> Result<(), IoError> { + self.v.serialize_into(writer) + } + fn deserialize_from(reader: R) -> Result { + Ok(NormalInputUsize { + v: usize::deserialize_from(reader)?, + }) + } +} + +impl Serde for Gate { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { for input in &self.inputs { input.serialize_into(&mut writer)?; @@ -51,9 +97,9 @@ impl Serde for Gate { Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let mut inputs = [0; INPUT_NUM]; + let mut inputs = [I::Input::default(); INPUT_NUM]; for input in inputs.iter_mut() { - *input = usize::deserialize_from(&mut reader)?; + *input = I::Input::deserialize_from(&mut reader)?; } let output = usize::deserialize_from(&mut reader)?; let coef = Coef::deserialize_from(&mut reader)?; @@ -65,14 +111,14 @@ impl Serde for Gate { } } -impl Serde for Allocation { +impl Serde for Allocation { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.input_offset.serialize_into(&mut writer)?; self.output_offset.serialize_into(&mut writer)?; Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let input_offset = usize::deserialize_from(&mut reader)?; + let input_offset = I::InputUsize::deserialize_from(&mut reader)?; let output_offset = usize::deserialize_from(&mut reader)?; Ok(Allocation { input_offset, @@ -81,7 +127,7 @@ impl Serde for Allocation { } } -impl Serde for ChildSpec { +impl Serde for ChildSpec { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.0.serialize_into(&mut writer)?; self.1.serialize_into(&mut writer)?; @@ -89,12 +135,12 @@ impl Serde for ChildSpec { } fn deserialize_from(mut reader: R) -> Result { let sub_circuit_id = usize::deserialize_from(&mut reader)?; - let allocs = Vec::::deserialize_from(&mut reader)?; + let allocs = Vec::>::deserialize_from(&mut reader)?; Ok((sub_circuit_id, allocs)) } } -impl Serde for GateCustom { +impl Serde for GateCustom { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.gate_type.serialize_into(&mut writer)?; self.inputs.serialize_into(&mut writer)?; @@ -104,7 +150,7 @@ impl Serde for GateCustom { } fn deserialize_from(mut reader: R) -> Result { let gate_type = usize::deserialize_from(&mut reader)?; - let inputs = Vec::::deserialize_from(&mut reader)?; + let inputs = Vec::::deserialize_from(&mut reader)?; let output = usize::deserialize_from(&mut reader)?; let coef = Coef::::deserialize_from(&mut reader)?; Ok(GateCustom { @@ -116,7 +162,7 @@ impl Serde for GateCustom { } } -impl Serde for Segment { +impl Serde for Segment { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.num_inputs.serialize_into(&mut writer)?; self.num_outputs.serialize_into(&mut writer)?; @@ -128,13 +174,13 @@ impl Serde for Segment { Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let num_inputs = usize::deserialize_from(&mut reader)?; + let num_inputs = I::InputUsize::deserialize_from(&mut reader)?; let num_outputs = usize::deserialize_from(&mut reader)?; - let child_segs = Vec::::deserialize_from(&mut reader)?; - let gate_muls = Vec::>::deserialize_from(&mut reader)?; - let gate_adds = Vec::>::deserialize_from(&mut reader)?; - let gate_consts = Vec::>::deserialize_from(&mut reader)?; - let gate_customs = Vec::>::deserialize_from(&mut reader)?; + let child_segs = Vec::>::deserialize_from(&mut reader)?; + let gate_muls = Vec::>::deserialize_from(&mut reader)?; + let gate_adds = Vec::>::deserialize_from(&mut reader)?; + let gate_consts = Vec::>::deserialize_from(&mut reader)?; + let gate_customs = Vec::>::deserialize_from(&mut reader)?; Ok(Segment { num_inputs, num_outputs, @@ -149,10 +195,10 @@ impl Serde for Segment { const MAGIC: usize = 3914834606642317635; -impl Serde for Circuit { +impl Serde for Circuit { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { MAGIC.serialize_into(&mut writer)?; - C::CircuitField::modulus().serialize_into(&mut writer)?; + C::CircuitField::MODULUS.serialize_into(&mut writer)?; self.num_public_inputs.serialize_into(&mut writer)?; self.num_actual_outputs.serialize_into(&mut writer)?; self.expected_num_output_zeroes @@ -170,7 +216,7 @@ impl Serde for Circuit { )); } let modulus = ethnum::U256::deserialize_from(&mut reader)?; - if modulus != C::CircuitField::modulus() { + if modulus != C::CircuitField::MODULUS { return Err(IoError::new( std::io::ErrorKind::InvalidData, "invalid modulus", @@ -179,7 +225,7 @@ impl Serde for Circuit { let num_public_inputs = usize::deserialize_from(&mut reader)?; let num_actual_outputs = usize::deserialize_from(&mut reader)?; let expected_num_output_zeroes = usize::deserialize_from(&mut reader)?; - let segments = Vec::>::deserialize_from(&mut reader)?; + let segments = Vec::>::deserialize_from(&mut reader)?; let layer_ids = Vec::::deserialize_from(&mut reader)?; Ok(Circuit { num_public_inputs, @@ -199,7 +245,7 @@ mod tests { ir::{common::rand_gen::*, dest::RootCircuit}, }; - fn test_serde_for_field() { + fn test_serde_for_field() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 20 }, @@ -218,15 +264,18 @@ mod tests { assert_eq!(circuit.validate(), Ok(())); let mut buf = Vec::new(); circuit.serialize_into(&mut buf).unwrap(); - let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); + let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); assert_eq!(circuit, circuit2); } } #[test] fn test_serde() { - test_serde_for_field::(); - test_serde_for_field::(); - test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); } } diff --git a/expander_compiler/src/circuit/layered/stats.rs b/expander_compiler/src/circuit/layered/stats.rs index 4549fe1b..8528391f 100644 --- a/expander_compiler/src/circuit/layered/stats.rs +++ b/expander_compiler/src/circuit/layered/stats.rs @@ -1,6 +1,6 @@ use crate::circuit::config::Config; -use super::Circuit; +use super::{Circuit, InputType, InputUsize}; pub struct Stats { // number of layers in the final circuit @@ -31,7 +31,7 @@ struct CircuitStats { num_expanded_cst: usize, } -impl Circuit { +impl Circuit { pub fn get_stats(&self) -> Stats { let mut m: Vec = Vec::with_capacity(self.segments.len()); let mut ar = Stats { @@ -83,12 +83,20 @@ impl Circuit { } } } - for i in 0..self.segments[self.layer_ids[0]].num_inputs { - if input_mask[self.layer_ids[0]][i] { + let mut global_input_mask = vec![false; self.input_size()]; + for (l, &id) in self.layer_ids.iter().enumerate() { + if self.segments[id].num_inputs.len() > l { + for (g, i) in global_input_mask.iter_mut().zip(input_mask[id][l].iter()) { + *g |= *i; + } + } + } + for x in global_input_mask.iter() { + if *x { ar.num_inputs += 1; } } - ar.total_cost = self.segments[self.layer_ids[0]].num_inputs * C::COST_INPUT; + ar.total_cost = self.input_size() * C::COST_INPUT; ar.total_cost += ar.num_total_gates * C::COST_VARIABLE; ar.total_cost += ar.num_expanded_mul * C::COST_MUL; ar.total_cost += ar.num_expanded_add * C::COST_ADD; diff --git a/expander_compiler/src/circuit/layered/tests.rs b/expander_compiler/src/circuit/layered/tests.rs index c0360537..434eb454 100644 --- a/expander_compiler/src/circuit/layered/tests.rs +++ b/expander_compiler/src/circuit/layered/tests.rs @@ -1,22 +1,25 @@ +use std::vec; + use super::{Allocation, Circuit, Coef, GateAdd, GateConst, GateMul, Segment}; use crate::circuit::config::{Config, M31Config as C}; +use crate::circuit::layered::{NormalInput, NormalInputType, NormalInputUsize}; use crate::field::FieldArith; type CField = ::CircuitField; #[test] fn simple() { - let circuit: Circuit = Circuit { + let circuit: Circuit = Circuit { num_public_inputs: 0, num_actual_outputs: 2, expected_num_output_zeroes: 0, segments: vec![ Segment { - num_inputs: 2, + num_inputs: NormalInputUsize { v: 2 }, num_outputs: 1, child_segs: vec![], gate_muls: vec![GateMul { - inputs: [0, 1], + inputs: [NormalInput { offset: 0 }, NormalInput { offset: 1 }], output: 0, coef: Coef::Constant(CField::from(2)), }], @@ -25,17 +28,17 @@ fn simple() { gate_customs: vec![], }, Segment { - num_inputs: 4, + num_inputs: NormalInputUsize { v: 4 }, num_outputs: 2, child_segs: vec![( 0, vec![ Allocation { - input_offset: 0, + input_offset: NormalInputUsize { v: 0 }, output_offset: 0, }, Allocation { - input_offset: 2, + input_offset: NormalInputUsize { v: 2 }, output_offset: 1, }, ], @@ -46,24 +49,24 @@ fn simple() { gate_customs: vec![], }, Segment { - num_inputs: 2, + num_inputs: NormalInputUsize { v: 2 }, num_outputs: 2, child_segs: vec![( 0, vec![Allocation { - input_offset: 0, + input_offset: NormalInputUsize { v: 0 }, output_offset: 0, }], )], gate_muls: vec![], gate_adds: vec![ GateAdd { - inputs: [0], + inputs: [NormalInput { offset: 0 }], output: 1, coef: Coef::Constant(CField::from(3)), }, GateAdd { - inputs: [1], + inputs: [NormalInput { offset: 1 }], output: 1, coef: Coef::Constant(CField::from(4)), }, diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index 8e955c7a..ded1eaf5 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -9,7 +9,7 @@ pub struct Witness { pub values: Vec, } -impl Circuit { +impl Circuit { pub fn run(&self, witness: &Witness) -> Vec { if witness.num_witnesses == 0 { panic!("expected at least 1 witness") @@ -33,18 +33,18 @@ impl Witness { where T: arith::SimdField, { - match self.num_witnesses.cmp(&T::pack_size()) { + match self.num_witnesses.cmp(&T::PACK_SIZE) { std::cmp::Ordering::Less => { println!( "Warning: not enough witnesses, expect {}, got {}", - T::pack_size(), + T::PACK_SIZE, self.num_witnesses ) } std::cmp::Ordering::Greater => { println!( "Warning: dropping additional witnesses, expect {}, got {}", - T::pack_size(), + T::PACK_SIZE, self.num_witnesses ) } @@ -55,10 +55,10 @@ impl Witness { let mut res = Vec::with_capacity(ni); let mut res_public = Vec::with_capacity(np); for i in 0..ni + np { - let mut values: Vec = (0..self.num_witnesses.min(T::pack_size())) + let mut values: Vec = (0..self.num_witnesses.min(T::PACK_SIZE)) .map(|j| self.values[j * (ni + np) + i]) .collect(); - values.resize(T::pack_size(), C::CircuitField::zero()); + values.resize(T::PACK_SIZE, C::CircuitField::zero()); let simd_value = T::pack(&values); if i < ni { res.push(simd_value); @@ -76,7 +76,7 @@ impl Serde for Witness { let num_inputs_per_witness = usize::deserialize_from(&mut reader)?; let num_public_inputs_per_witness = usize::deserialize_from(&mut reader)?; let modulus = ethnum::U256::deserialize_from(&mut reader)?; - if modulus != C::CircuitField::modulus() { + if modulus != C::CircuitField::MODULUS { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "invalid modulus", @@ -100,7 +100,7 @@ impl Serde for Witness { self.num_inputs_per_witness.serialize_into(&mut writer)?; self.num_public_inputs_per_witness .serialize_into(&mut writer)?; - C::CircuitField::modulus().serialize_into(&mut writer)?; + C::CircuitField::MODULUS.serialize_into(&mut writer)?; for v in &self.values { v.serialize_into(&mut writer)?; } diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index a3fa6a06..848a4a7f 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -1,6 +1,11 @@ use crate::{ builder, - circuit::{config::Config, input_mapping::InputMapping, ir, layered}, + circuit::{ + config::Config, + input_mapping::InputMapping, + ir, + layered::{self, InputType}, + }, layering, utils::error::Error, }; @@ -10,6 +15,18 @@ mod random_circuit_tests; #[cfg(test)] mod tests; +#[derive(Default)] +pub struct CompileOptions { + pub mul_fanout_limit: Option, +} + +impl CompileOptions { + pub fn with_mul_fanout_limit(mut self, mul_fanout_limit: usize) -> Self { + self.mul_fanout_limit = Some(mul_fanout_limit); + self + } +} + fn optimize_until_fixed_point(x: &T, im: &mut InputMapping, f: F) -> T where T: Clone + Eq, @@ -47,9 +64,16 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { } } -pub fn compile( +pub fn compile( + r_source: &ir::source::RootCircuit, +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { + compile_with_options(r_source, CompileOptions::default()) +} + +pub fn compile_with_options( r_source: &ir::source::RootCircuit, -) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { + options: CompileOptions, +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { r_source.validate()?; let mut src_im = InputMapping::new_identity(r_source.input_size()); @@ -114,6 +138,15 @@ pub fn compile( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_opt = if let Some(limit) = options.mul_fanout_limit { + r_dest_relaxed_opt.solve_mul_fanout_limit(limit) + } else { + r_dest_relaxed_opt + }; + r_dest_relaxed_opt + .validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_p2 = if C::ENABLE_RANDOM_COMBINATION { r_dest_relaxed_opt } else { @@ -127,18 +160,21 @@ pub fn compile( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; - let r_dest_relaxed_p3 = layering::ir_split::split_to_single_layer(&r_dest_relaxed_p2); - r_dest_relaxed_p3 - .validate() - .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; - - let r_dest_relaxed_p3_opt = optimize_until_fixed_point(&r_dest_relaxed_p3, &mut hl_im, |r| { - let (mut r, im) = r.remove_unreachable(); - r.reassign_duplicate_sub_circuit_outputs(); - (r, im) - }); + let r_dest_relaxed_p3 = if I::CROSS_LAYER_RELAY { + r_dest_relaxed_p2 + } else { + let r = layering::ir_split::split_to_single_layer(&r_dest_relaxed_p2); + r.validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + + optimize_until_fixed_point(&r, &mut hl_im, |r| { + let (mut r, im) = r.remove_unreachable(); + r.reassign_duplicate_sub_circuit_outputs(); + (r, im) + }) + }; - let r_dest = r_dest_relaxed_p3_opt.solve_duplicates(); + let r_dest = r_dest_relaxed_p3.solve_duplicates(); let r_dest_opt = optimize_until_fixed_point(&r_dest, &mut hl_im, |r| { let (mut r, im) = r.remove_unreachable(); diff --git a/expander_compiler/src/compile/random_circuit_tests.rs b/expander_compiler/src/compile/random_circuit_tests.rs index 74cca80e..b4bae12c 100644 --- a/expander_compiler/src/compile/random_circuit_tests.rs +++ b/expander_compiler/src/compile/random_circuit_tests.rs @@ -5,18 +5,19 @@ use crate::{ common::rand_gen::{RandomCircuitConfig, RandomRange}, source::RootCircuit as IrSourceRoot, }, + layered::{CrossLayerInputType, InputType, NormalInputType}, }, compile::compile, field::FieldArith, utils::error::Error, }; -fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { +fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { for i in seed.min..seed.max { config.seed = i; let root = IrSourceRoot::::random(&config); assert_eq!(root.validate(), Ok(())); - let res = compile(&root); + let res = compile::<_, I>(&root); match res { Ok((ir_hint_normalized, layered_circuit)) => { assert_eq!(ir_hint_normalized.validate(), Ok(())); @@ -60,7 +61,7 @@ fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { } } -fn do_tests(seed: usize) { +fn do_tests(seed: usize) { let config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -71,7 +72,7 @@ fn do_tests(seed: usize) { num_terms: RandomRange { min: 1, max: 5 }, sub_circuit_prob: 0.5, }; - do_test::( + do_test::( config, RandomRange { min: 100000 + seed, @@ -88,7 +89,7 @@ fn do_tests(seed: usize) { num_terms: RandomRange { min: 1, max: 5 }, sub_circuit_prob: 0.05, }; - do_test::( + do_test::( config, RandomRange { min: 200000 + seed, @@ -99,21 +100,35 @@ fn do_tests(seed: usize) { #[test] fn test_m31() { - do_tests::(1000000); + do_tests::(1000000); } #[test] fn test_bn254() { - do_tests::(2000000); + do_tests::(2000000); } #[test] fn test_gf2() { - do_tests::(3000000); + do_tests::(3000000); } #[test] -fn deterministic() { +fn test_m31_cross() { + do_tests::(4000000); +} + +#[test] +fn test_bn254_cross() { + do_tests::(5000000); +} + +#[test] +fn test_gf2_cross() { + do_tests::(6000000); +} + +fn deterministic_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -128,8 +143,8 @@ fn deterministic() { config.seed = i; let root = IrSourceRoot::::random(&config); assert_eq!(root.validate(), Ok(())); - let res = compile(&root); - let res2 = compile(&root); + let res = compile::<_, I>(&root); + let res2 = compile::<_, I>(&root); match (res, res2) { ( Ok((ir_hint_normalized, layered_circuit)), @@ -157,3 +172,13 @@ fn deterministic() { } } } + +#[test] +fn deterministic_normal() { + deterministic_::(); +} + +#[test] +fn deterministic_cross() { + deterministic_::(); +} diff --git a/expander_compiler/src/compile/tests.rs b/expander_compiler/src/compile/tests.rs index 224de57f..db4d327e 100644 --- a/expander_compiler/src/compile/tests.rs +++ b/expander_compiler/src/compile/tests.rs @@ -1,6 +1,7 @@ use crate::circuit::{ config::{Config, M31Config as C}, ir, + layered::NormalInputType, }; type CField = ::CircuitField; @@ -25,7 +26,7 @@ fn simple_div() { }, ); assert_eq!(root.validate(), Ok(())); - let (input_solver, lc) = super::compile(&root).unwrap(); + let (input_solver, lc) = super::compile::<_, NormalInputType>(&root).unwrap(); assert_eq!(input_solver.circuits[&0].outputs.len(), 4); let (o, cond) = lc.eval_unsafe(vec![ CField::from(2), diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index d0bad08d..7a9f2688 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -19,6 +19,8 @@ pub trait BasicAPI { binary_op!(xor); binary_op!(or); binary_op!(and); + + fn display(&self, _label: &str, _x: impl ToVariableOrValue) {} fn div( &mut self, x: impl ToVariableOrValue, @@ -26,7 +28,9 @@ pub trait BasicAPI { checked: bool, ) -> Variable; fn neg(&mut self, x: impl ToVariableOrValue) -> Variable; - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable; + fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { + self.div(1, x, true) + } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable; fn assert_is_zero(&mut self, x: impl ToVariableOrValue); fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue); @@ -35,13 +39,32 @@ pub trait BasicAPI { &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, - ); + ) { + let diff = self.sub(x, y); + self.assert_is_zero(diff); + } fn assert_is_different( &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, - ); + ) { + let diff = self.sub(x, y); + self.assert_is_non_zero(diff); + } fn get_random_value(&mut self) -> Variable; + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec; + fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; + // try to get the value of a compile-time constant variable + // this function has different behavior in normal and debug mode, in debug mode it always returns Some(value) + fn constant_value( + &mut self, + x: impl ToVariableOrValue, + ) -> Option; } pub trait UnconstrainedAPI { @@ -66,3 +89,11 @@ pub trait UnconstrainedAPI { binary_op!(unconstrained_bit_and); binary_op!(unconstrained_bit_xor); } + +pub trait RootAPI: Sized + BasicAPI + UnconstrainedAPI + 'static { + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec; +} diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 220927d3..bc92c972 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -13,16 +13,17 @@ use crate::{ layered::Coef, }, field::{Field, FieldArith}, - hints, + hints::{self, registry::hint_key_to_id}, utils::function_id::get_function_id, }; -use super::api::{BasicAPI, UnconstrainedAPI}; +use super::api::{BasicAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, constraints: Vec, - var_max: usize, + var_const_id: Vec, + const_values: Vec, num_inputs: usize, } @@ -31,12 +32,26 @@ pub struct Variable { id: usize, } +impl Variable { + pub fn id(&self) -> usize { + self.id + } +} + +pub fn new_variable(id: usize) -> Variable { + Variable { id } +} + +pub fn get_variable_id(v: Variable) -> usize { + v.id +} + pub enum VariableOrValue { Variable(Variable), Value(F), } -pub trait ToVariableOrValue { +pub trait ToVariableOrValue: Clone { fn convert_to_variable_or_value(self) -> VariableOrValue; } @@ -45,7 +60,7 @@ impl NotVariable for u32 {} impl NotVariable for U256 {} impl NotVariable for F {} -impl + NotVariable> ToVariableOrValue for T { +impl + NotVariable + Clone> ToVariableOrValue for T { fn convert_to_variable_or_value(self) -> VariableOrValue { VariableOrValue::Value(self.into()) } @@ -69,8 +84,9 @@ impl Builder { Builder { instructions: Vec::new(), constraints: Vec::new(), - var_max: num_inputs, num_inputs, + var_const_id: vec![0; num_inputs + 1], + const_values: vec![C::CircuitField::zero()], }, (1..=num_inputs).map(|id| Variable { id }).collect(), ) @@ -91,15 +107,20 @@ impl Builder { VariableOrValue::Value(v) => { self.instructions .push(SourceInstruction::ConstantLike(Coef::Constant(v))); - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(self.const_values.len()); + self.const_values.push(v); + Variable { + id: self.var_const_id.len() - 1, + } } } } fn new_var(&mut self) -> Variable { - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(0); + Variable { + id: self.var_const_id.len() - 1, + } } } @@ -109,6 +130,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv + yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -132,6 +160,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv - yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -151,6 +186,10 @@ impl BasicAPI for Builder { } fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(-xv); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::LinComb(LinComb { terms: vec![LinCombTerm { @@ -167,6 +206,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv * yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions @@ -180,6 +226,21 @@ impl BasicAPI for Builder { y: impl ToVariableOrValue, checked: bool, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + let res = if yv.is_zero() { + if checked || !xv.is_zero() { + panic!("division by zero"); + } + C::CircuitField::zero() + } else { + xv * yv.inv().unwrap() + }; + return self.constant(res); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::Div { @@ -190,15 +251,20 @@ impl BasicAPI for Builder { self.new_var() } - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { - self.div(1, x, true) - } - fn xor( &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from((xv != yv) as u32)); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -214,6 +280,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() || !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -229,6 +306,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() && !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -240,12 +328,22 @@ impl BasicAPI for Builder { } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(C::CircuitField::from(xv.is_zero() as u32)); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::IsZero(x.id)); self.new_var() } fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Zero, @@ -254,6 +352,12 @@ impl BasicAPI for Builder { } fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::NonZero, @@ -262,6 +366,12 @@ impl BasicAPI for Builder { } fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() && xv != C::CircuitField::one() { + panic!("assert_is_bool failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Bool, @@ -269,28 +379,45 @@ impl BasicAPI for Builder { }); } - fn assert_is_equal( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - let diff = self.sub(x, y); - self.assert_is_zero(diff); + fn get_random_value(&mut self) -> Variable { + self.instructions + .push(SourceInstruction::ConstantLike(Coef::Random)); + self.new_var() } - fn assert_is_different( + fn new_hint( &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - let diff = self.sub(x, y); - self.assert_is_non_zero(diff); + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.instructions.push(SourceInstruction::Hint { + hint_id: hint_key_to_id(hint_key), + inputs: inputs.iter().map(|v| v.id).collect(), + num_outputs, + }); + (0..num_outputs).map(|_| self.new_var()).collect() } - fn get_random_value(&mut self) -> Variable { - self.instructions - .push(SourceInstruction::ConstantLike(Coef::Random)); - self.new_var() + fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { + self.convert_to_variable(value) + } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + match x.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => { + let t = self.var_const_id[v.id]; + if t != 0 { + Some(self.const_values[t]) + } else { + None + } + } + VariableOrValue::Value(v) => Some(v), + } } } @@ -406,10 +533,6 @@ impl BasicAPI for RootBuilder { self.last_builder().div(x, y, checked) } - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { - self.last_builder().inverse(x) - } - fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { self.last_builder().is_zero(x) } @@ -426,24 +549,54 @@ impl BasicAPI for RootBuilder { self.last_builder().assert_is_bool(x) } - fn assert_is_equal( + fn get_random_value(&mut self) -> Variable { + self.last_builder().get_random_value() + } + + fn new_hint( &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - self.last_builder().assert_is_equal(x, y) + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.last_builder().new_hint(hint_key, inputs, num_outputs) + } + + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { + self.last_builder().constant(x) } - fn assert_is_different( + fn constant_value( &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - self.last_builder().assert_is_different(x, y) + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + self.last_builder().constant_value(x) } +} - fn get_random_value(&mut self) -> Variable { - self.last_builder().get_random_value() +impl RootAPI for RootBuilder { + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(b"simple"); + hasher.update(&inputs.len().to_le_bytes()); + hasher.update(&get_function_id::().to_le_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { + if *prev_hash != hash { + panic!("subcircuit id collision"); + } + } else { + self.full_hash_id.insert(circuit_id, hash); + } + + self.call_sub_circuit(circuit_id, inputs, f) } } @@ -484,7 +637,7 @@ impl RootBuilder { } } - fn last_builder(&mut self) -> &mut Builder { + pub fn last_builder(&mut self) -> &mut Builder { &mut self.current_builders.last_mut().unwrap().1 } @@ -524,34 +677,6 @@ impl RootBuilder { }); outputs } - - pub fn memorized_simple_call) -> Vec + 'static>( - &mut self, - f: F, - inputs: &[Variable], - ) -> Vec { - let mut hasher = tiny_keccak::Keccak::v256(); - hasher.update(b"simple"); - hasher.update(&inputs.len().to_le_bytes()); - hasher.update(&get_function_id::().to_le_bytes()); - let mut hash = [0u8; 32]; - hasher.finalize(&mut hash); - - let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); - if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { - if *prev_hash != hash { - panic!("subcircuit id collision"); - } - } else { - self.full_hash_id.insert(circuit_id, hash); - } - - self.call_sub_circuit(circuit_id, inputs, f) - } - - pub fn constant>(&mut self, value: T) -> Variable { - self.last_builder().convert_to_variable(value) - } } impl UnconstrainedAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index 6f9c65d7..90cd7d19 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -164,7 +164,12 @@ pub use declare_circuit_num_vars; use crate::circuit::config::Config; +use super::api::RootAPI; use super::builder::RootBuilder; pub trait Define { fn define(&self, api: &mut RootBuilder); } + +pub trait GenericDefine { + fn define>(&self, api: &mut Builder); +} diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs new file mode 100644 index 00000000..0b97b111 --- /dev/null +++ b/expander_compiler/src/frontend/debug.rs @@ -0,0 +1,464 @@ +use crate::{ + circuit::{ + config::Config, + ir::{ + common::{EvalResult, Instruction}, + source::{BoolBinOpType, Instruction as IrInstruction, UnconstrainedBinOpType}, + }, + }, + field::FieldArith, + hints::registry::{hint_key_to_id, HintCaller}, +}; + +use super::{ + api::{BasicAPI, RootAPI, UnconstrainedAPI}, + builder::{get_variable_id, new_variable, ToVariableOrValue, VariableOrValue}, + Variable, +}; + +pub struct DebugBuilder> { + values: Vec, + hint_caller: H, +} + +impl> BasicAPI for DebugBuilder { + fn display(&self, str: &str, x: impl ToVariableOrValue<::CircuitField>) { + let x = self.convert_to_value(x); + println!("{}: {:?}", str, x); + } + + fn add( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x + y) + } + fn sub( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x - y) + } + fn mul( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x * y) + } + fn xor( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::Xor, + }) + } + fn or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::Or, + }) + } + fn and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::And, + }) + } + fn div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + checked: bool, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::Div { x, y, checked }) + } + fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let x = self.convert_to_value(x); + self.return_as_variable(-x) + } + fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let x = self.convert_to_id(x); + self.eval_ir_insn(IrInstruction::IsZero(x)) + } + fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(x.is_zero()); + } + fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(!x.is_zero()); + } + fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(x.is_zero() || x == C::CircuitField::one()); + } + fn get_random_value(&mut self) -> Variable { + let v = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + self.return_as_variable(v) + } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + let inputs: Vec = + inputs.iter().map(|v| self.convert_to_value(v)).collect(); + match self + .hint_caller + .call(hint_key_to_id(hint_key), &inputs, num_outputs) + { + Ok(outputs) => outputs + .into_iter() + .map(|v| self.return_as_variable(v)) + .collect(), + Err(e) => panic!("Hint error: {:?}", e), + } + } + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { + let x = self.convert_to_value(x); + self.return_as_variable(x) + } + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + Some(self.convert_to_value(x)) + } +} + +impl> UnconstrainedAPI for DebugBuilder { + fn unconstrained_identity(&mut self, x: impl ToVariableOrValue) -> Variable { + self.constant(x) + } + fn unconstrained_add( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + self.add(x, y) + } + fn unconstrained_mul( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + self.mul(x, y) + } + fn unconstrained_div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Div, + }) + } + fn unconstrained_pow( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Pow, + }) + } + fn unconstrained_int_div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::IntDiv, + }) + } + fn unconstrained_mod( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Mod, + }) + } + fn unconstrained_shift_l( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::ShiftL, + }) + } + fn unconstrained_shift_r( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::ShiftR, + }) + } + fn unconstrained_lesser_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::LesserEq, + }) + } + fn unconstrained_greater_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::GreaterEq, + }) + } + fn unconstrained_lesser( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Lesser, + }) + } + fn unconstrained_greater( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Greater, + }) + } + fn unconstrained_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Eq, + }) + } + fn unconstrained_not_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::NotEq, + }) + } + fn unconstrained_bool_or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BoolOr, + }) + } + fn unconstrained_bool_and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BoolAnd, + }) + } + fn unconstrained_bit_or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitOr, + }) + } + fn unconstrained_bit_and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitAnd, + }) + } + fn unconstrained_bit_xor( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitXor, + }) + } +} + +impl> RootAPI for DebugBuilder { + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec { + let inputs = inputs.to_vec(); + f(self, &inputs) + } +} + +impl> DebugBuilder { + pub fn new( + inputs: Vec, + public_inputs: Vec, + hint_caller: H, + ) -> (Self, Vec, Vec) { + let mut builder = DebugBuilder { + values: vec![C::CircuitField::zero()], + hint_caller, + }; + let vars = (1..=inputs.len()).map(new_variable).collect(); + let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) + .map(new_variable) + .collect(); + builder.values.extend(inputs); + builder.values.extend(public_inputs); + (builder, vars, public_vars) + } + + fn convert_to_value>(&self, value: T) -> C::CircuitField { + match value.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => self.values[get_variable_id(v)], + VariableOrValue::Value(v) => v, + } + } + + fn convert_to_id>(&mut self, value: T) -> usize { + match value.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => get_variable_id(v), + VariableOrValue::Value(v) => { + let id = self.values.len(); + self.values.push(v); + id + } + } + } + + fn return_as_variable(&mut self, value: C::CircuitField) -> Variable { + let id = self.values.len(); + self.values.push(value); + new_variable(id) + } + + fn eval_ir_insn(&mut self, insn: IrInstruction) -> Variable { + match insn.eval_unsafe(&self.values) { + EvalResult::Error(e) => panic!("error: {:?}", e), + EvalResult::SubCircuitCall(_, _) => unreachable!(), + EvalResult::Value(v) => self.return_as_variable(v), + EvalResult::Values(_) => unreachable!(), + } + } +} diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 1b087b34..761ead61 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -1,21 +1,25 @@ use builder::RootBuilder; +use crate::circuit::layered::{CrossLayerInputType, NormalInputType}; use crate::circuit::{ir, layered}; mod api; mod builder; mod circuit; +mod debug; mod variables; mod witness; pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; -pub use crate::field::{Field, BN254, GF2, M31}; +pub use crate::compile::CompileOptions; +pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31}; +pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::error::Error; -pub use api::BasicAPI; +pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; -pub use circuit::Define; +pub use circuit::{Define, GenericDefine}; pub use witness::WitnessSolver; pub mod internal { @@ -29,7 +33,37 @@ pub mod internal { pub mod extra { pub use super::api::UnconstrainedAPI; + pub use super::debug::DebugBuilder; + pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::serde::Serde; + + use super::*; + + pub fn debug_eval< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, + CA: internal::DumpLoadTwoVariables, + H: HintCaller, + >( + circuit: &Cir, + assignment: &CA, + hint_caller: H, + ) { + let (num_inputs, num_public_inputs) = circuit.num_vars(); + let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); + assert_eq!(num_inputs, a_num_inputs); + assert_eq!(num_public_inputs, a_num_public_inputs); + let mut inputs = Vec::new(); + let mut public_inputs = Vec::new(); + assignment.dump_into(&mut inputs, &mut public_inputs); + let (mut root_builder, input_variables, public_input_variables) = + DebugBuilder::::new(inputs, public_inputs, hint_caller); + let mut circuit = circuit.clone(); + let mut vars_ptr = input_variables.as_slice(); + let mut public_vars_ptr = public_input_variables.as_slice(); + circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); + circuit.define(&mut root_builder); + } } #[cfg(test)] @@ -51,16 +85,68 @@ fn build + Define + pub struct CompileResult { pub witness_solver: WitnessSolver, - pub layered_circuit: layered::Circuit, + pub layered_circuit: layered::Circuit, +} + +pub struct CompileResultCrossLayer { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, } pub fn compile + Define + Clone>( circuit: &Cir, ) -> Result, Error> { let root = build(circuit); - let (irw, lc) = crate::compile::compile::(&root)?; + let (irw, lc) = crate::compile::compile::(&root)?; + Ok(CompileResult { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} + +fn build_generic< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, +) -> ir::source::RootCircuit { + let (num_inputs, num_public_inputs) = circuit.num_vars(); + let (mut root_builder, input_variables, public_input_variables) = + RootBuilder::::new(num_inputs, num_public_inputs); + let mut circuit = circuit.clone(); + let mut vars_ptr = input_variables.as_slice(); + let mut public_vars_ptr = public_input_variables.as_slice(); + circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); + circuit.define(&mut root_builder); + root_builder.build() +} + +pub fn compile_generic< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build_generic(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, layered_circuit: lc, }) } + +pub fn compile_generic_cross_layer< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build_generic(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + Ok(CompileResultCrossLayer { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} diff --git a/expander_compiler/src/frontend/tests.rs b/expander_compiler/src/frontend/tests.rs index ed967b87..31bc9102 100644 --- a/expander_compiler/src/frontend/tests.rs +++ b/expander_compiler/src/frontend/tests.rs @@ -44,9 +44,9 @@ fn test_circuit_declaration() { c.dump_into(&mut vars, &mut public_vars); assert_eq!((vars.len(), public_vars.len()), c.num_vars()); let mut c2 = Circuit1::::default(); - let mut vars_ref = &mut vars.as_slice(); - let mut public_vars_ref = &mut public_vars.as_slice(); - c2.load_from(&mut vars_ref, &mut public_vars_ref); + let vars_ref = &mut vars.as_slice(); + let public_vars_ref = &mut public_vars.as_slice(); + c2.load_from(vars_ref, public_vars_ref); assert_eq!(vars_ref.len(), 0); assert_eq!(public_vars_ref.len(), 0); assert_eq!(c.a, c2.a); diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index f686fe15..06b4f5bb 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,5 +1,8 @@ pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; -use crate::circuit::layered::witness::Witness; +use crate::{ + circuit::layered::witness::Witness, + hints::registry::{EmptyHintCaller, HintCaller}, +}; use super::{internal, Config, Error}; @@ -7,22 +10,42 @@ impl WitnessSolver { pub fn solve_witness>( &self, assignment: &Cir, + ) -> Result, Error> { + self.solve_witness_with_hints(assignment, &mut EmptyHintCaller) + } + + pub fn solve_witness_with_hints>( + &self, + assignment: &Cir, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut vars = Vec::new(); let mut public_vars = Vec::new(); assignment.dump_into(&mut vars, &mut public_vars); - self.solve_witness_from_raw_inputs(vars, public_vars) + self.solve_witness_from_raw_inputs(vars, public_vars, hint_caller) } pub fn solve_witnesses>( &self, assignments: &[Cir], ) -> Result, Error> { - self.solve_witnesses_from_raw_inputs(assignments.len(), |i| { - let mut vars = Vec::new(); - let mut public_vars = Vec::new(); - assignments[i].dump_into(&mut vars, &mut public_vars); - (vars, public_vars) - }) + self.solve_witnesses_with_hints(assignments, &mut EmptyHintCaller) + } + + pub fn solve_witnesses_with_hints>( + &self, + assignments: &[Cir], + hint_caller: &mut impl HintCaller, + ) -> Result, Error> { + self.solve_witnesses_from_raw_inputs( + assignments.len(), + |i| { + let mut vars = Vec::new(); + let mut public_vars = Vec::new(); + assignments[i].dump_into(&mut vars, &mut public_vars); + (vars, public_vars) + }, + hint_caller, + ) } } diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs new file mode 100644 index 00000000..76dc5ebe --- /dev/null +++ b/expander_compiler/src/hints/builtin.rs @@ -0,0 +1,321 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +use ethnum::U256; +use rand::RngCore; + +use crate::{field::Field, utils::error::Error}; + +#[repr(u64)] +pub enum BuiltinHintIds { + Identity = 0xccc000000000, + Div, + Eq, + NotEq, + BoolOr, + BoolAnd, + BitOr, + BitAnd, + BitXor, + Select, + Pow, + IntDiv, + Mod, + ShiftL, + ShiftR, + LesserEq, + GreaterEq, + Lesser, + Greater, +} + +#[cfg(not(target_pointer_width = "64"))] +compile_error!("compilation is only allowed for 64-bit targets"); + +impl BuiltinHintIds { + pub fn from_usize(id: usize) -> Option { + if id < (BuiltinHintIds::Identity as u64 as usize) { + return None; + } + if id > (BuiltinHintIds::Identity as u64 as usize + 100) { + return None; + } + match id { + x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), + x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), + x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), + x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), + x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), + x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), + x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), + x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), + x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), + x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), + x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), + x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), + x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), + x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), + x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), + x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), + x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), + x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), + x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), + _ => None, + } + } +} + +fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + let mut hasher = DefaultHasher::new(); + hint_id.hash(&mut hasher); + inputs.hash(&mut hasher); + let mut outputs = Vec::with_capacity(num_outputs); + for _ in 0..num_outputs { + let t = hasher.finish(); + outputs.push(F::from(t as u32)); + t.hash(&mut hasher); + } + outputs +} + +fn validate_builtin_hint( + hint_id: BuiltinHintIds, + num_inputs: usize, + num_outputs: usize, +) -> Result<(), Error> { + match hint_id { + BuiltinHintIds::Identity => { + if num_inputs != num_outputs { + return Err(Error::InternalError( + "identity hint requires exactly the same number of inputs and outputs" + .to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "identity hint requires at least 1 input".to_string(), + )); + } + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + if num_inputs != 2 { + return Err(Error::InternalError( + "binary op requires exactly 2 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "binary op requires exactly 1 output".to_string(), + )); + } + } + BuiltinHintIds::Select => { + if num_inputs != 3 { + return Err(Error::InternalError( + "select requires exactly 3 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "select requires exactly 1 output".to_string(), + )); + } + } + } + Ok(()) +} + +pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), + None => { + if num_outputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 output".to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 input".to_string(), + )); + } + Ok(()) + } + } +} + +pub fn impl_builtin_hint( + hint_id: BuiltinHintIds, + inputs: &[F], + num_outputs: usize, +) -> Vec { + match hint_id { + BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), + BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { + Some(inv) => x * inv, + None => F::zero(), + }), + BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), + BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), + BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() || !y.is_zero()) as u32) + }), + BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() && !y.is_zero()) as u32) + }), + BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), + BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), + BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), + BuiltinHintIds::Select => { + let mut outputs = Vec::with_capacity(num_outputs); + outputs.push(if !inputs[0].is_zero() { + inputs[1] + } else { + inputs[2] + }); + outputs + } + BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { + let mut t = x; + let mut res = F::one(); + let mut y: U256 = y.to_u256(); + while y != U256::ZERO { + if y & U256::from(1u32) != U256::ZERO { + res *= t; + } + y >>= 1; + t = t * t; + } + res + }), + BuiltinHintIds::IntDiv => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, + ) + } + BuiltinHintIds::Mod => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, + ) + } + BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), + BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), + BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), + BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), + BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), + BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), + } +} + +fn binop_hint F>(inputs: &[F], f: G) -> Vec { + vec![f(inputs[0], inputs[1])] +} + +fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { + let x_u256: U256 = inputs[0].to_u256(); + let y_u256: U256 = inputs[1].to_u256(); + let z_u256 = f(x_u256, y_u256); + vec![F::from_u256(z_u256)] +} + +pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), + None => stub_impl_general(hint_id, inputs, num_outputs), + } +} + +pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { + loop { + let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); + if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { + match hint_id { + BuiltinHintIds::Identity => { + let num_inputs = (rand.next_u64() % 10) as usize + 1; + let num_outputs = num_inputs; + return (hint_id as usize, num_inputs, num_outputs); + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + return (hint_id as usize, 2, 1); + } + BuiltinHintIds::Select => { + return (hint_id as usize, 3, 1); + } + } + } + } +} + +pub fn u256_bit_length(x: U256) -> usize { + 256 - x.leading_zeros() as usize +} + +pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { + let top = F::MODULUS / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::MODULUS) + }; + if shift >= 256 { + return U256::ZERO; + } + let value = x << shift; + let mask = U256::from(1u32) << u256_bit_length(F::MODULUS); + let mask = mask - 1; + value & mask + } else { + circom_shift_r_impl::(x, F::MODULUS - k) + } +} + +pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { + let top = F::MODULUS / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::MODULUS) + }; + if shift >= 256 { + return U256::ZERO; + } + x >> shift + } else { + circom_shift_l_impl::(x, F::MODULUS - k) + } +} diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index b9a312c6..05a8cf64 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -1,321 +1,20 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; +pub mod builtin; +pub mod registry; -use ethnum::U256; -use rand::RngCore; +pub use builtin::*; -use crate::{field::Field, utils::error::Error}; - -#[repr(u64)] -pub enum BuiltinHintIds { - Identity = 0xccc000000000, - Div, - Eq, - NotEq, - BoolOr, - BoolAnd, - BitOr, - BitAnd, - BitXor, - Select, - Pow, - IntDiv, - Mod, - ShiftL, - ShiftR, - LesserEq, - GreaterEq, - Lesser, - Greater, -} - -#[cfg(not(target_pointer_width = "64"))] -compile_error!("compilation is only allowed for 64-bit targets"); - -impl BuiltinHintIds { - pub fn from_usize(id: usize) -> Option { - if id < (BuiltinHintIds::Identity as u64 as usize) { - return None; - } - if id > (BuiltinHintIds::Identity as u64 as usize + 100) { - return None; - } - match id { - x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), - x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), - x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), - x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), - x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), - x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), - x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), - x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), - x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), - x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), - x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), - x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), - x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), - x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), - x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), - x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), - x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), - x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), - x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), - _ => None, - } - } -} - -fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { - let mut hasher = DefaultHasher::new(); - hint_id.hash(&mut hasher); - inputs.hash(&mut hasher); - let mut outputs = Vec::with_capacity(num_outputs); - for _ in 0..num_outputs { - let t = hasher.finish(); - outputs.push(F::from(t as u32)); - t.hash(&mut hasher); - } - outputs -} - -fn validate_builtin_hint( - hint_id: BuiltinHintIds, - num_inputs: usize, - num_outputs: usize, -) -> Result<(), Error> { - match hint_id { - BuiltinHintIds::Identity => { - if num_inputs != num_outputs { - return Err(Error::InternalError( - "identity hint requires exactly the same number of inputs and outputs" - .to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "identity hint requires at least 1 input".to_string(), - )); - } - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - if num_inputs != 2 { - return Err(Error::InternalError( - "binary op requires exactly 2 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "binary op requires exactly 1 output".to_string(), - )); - } - } - BuiltinHintIds::Select => { - if num_inputs != 3 { - return Err(Error::InternalError( - "select requires exactly 3 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "select requires exactly 1 output".to_string(), - )); - } - } - } - Ok(()) -} +use registry::HintCaller; -pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { - match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), - None => { - if num_outputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 output".to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 input".to_string(), - )); - } - Ok(()) - } - } -} +use crate::{field::Field, utils::error::Error}; -fn impl_builtin_hint( - hint_id: BuiltinHintIds, +pub fn safe_impl( + hint_caller: &mut impl HintCaller, + hint_id: usize, inputs: &[F], num_outputs: usize, -) -> Vec { - match hint_id { - BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), - BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { - Some(inv) => x * inv, - None => F::zero(), - }), - BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), - BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), - BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() || !y.is_zero()) as u32) - }), - BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() && !y.is_zero()) as u32) - }), - BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), - BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), - BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), - BuiltinHintIds::Select => { - let mut outputs = Vec::with_capacity(num_outputs); - outputs.push(if !inputs[0].is_zero() { - inputs[1] - } else { - inputs[2] - }); - outputs - } - BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { - let mut t = x; - let mut res = F::one(); - let mut y: U256 = y.to_u256(); - while y != U256::ZERO { - if y & U256::from(1u32) != U256::ZERO { - res *= t; - } - y >>= 1; - t = t * t; - } - res - }), - BuiltinHintIds::IntDiv => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, - ) - } - BuiltinHintIds::Mod => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, - ) - } - BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), - BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), - BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), - BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), - BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), - BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), - } -} - -fn binop_hint F>(inputs: &[F], f: G) -> Vec { - vec![f(inputs[0], inputs[1])] -} - -fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { - let x_u256: U256 = inputs[0].to_u256(); - let y_u256: U256 = inputs[1].to_u256(); - let z_u256 = f(x_u256, y_u256); - vec![F::from_u256(z_u256)] -} - -pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { +) -> Result, Error> { match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), - None => stub_impl_general(hint_id, inputs, num_outputs), - } -} - -pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { - loop { - let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); - if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { - match hint_id { - BuiltinHintIds::Identity => { - let num_inputs = (rand.next_u64() % 10) as usize + 1; - let num_outputs = num_inputs; - return (hint_id as usize, num_inputs, num_outputs); - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - return (hint_id as usize, 2, 1); - } - BuiltinHintIds::Select => { - return (hint_id as usize, 3, 1); - } - } - } - } -} - -pub fn u256_bit_length(x: U256) -> usize { - 256 - x.leading_zeros() as usize -} - -pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - let value = x << shift; - let mask = U256::from(1u32) << u256_bit_length(F::modulus()); - let mask = mask - 1; - value & mask - } else { - circom_shift_r_impl::(x, F::modulus() - k) - } -} - -pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - x >> shift - } else { - circom_shift_l_impl::(x, F::modulus() - k) + Some(hint_id) => Ok(impl_builtin_hint(hint_id, inputs, num_outputs)), + None => hint_caller.call(hint_id, inputs, num_outputs), } } diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs new file mode 100644 index 00000000..27ee0833 --- /dev/null +++ b/expander_compiler/src/hints/registry.rs @@ -0,0 +1,77 @@ +use std::collections::HashMap; + +use tiny_keccak::Hasher; + +use crate::{field::Field, utils::error::Error}; + +use super::BuiltinHintIds; + +pub type HintFn = dyn FnMut(&[F], &mut [F]) -> Result<(), Error>; + +#[derive(Default)] +pub struct HintRegistry { + hints: HashMap>>, +} + +pub fn hint_key_to_id(key: &str) -> usize { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(key.as_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let res = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if BuiltinHintIds::from_usize(res).is_some() { + panic!("Hint id {} collides with a builtin hint id", res); + } + res +} + +impl HintRegistry { + pub fn new() -> Self { + Self::default() + } + pub fn register Result<(), Error> + 'static>( + &mut self, + key: &str, + hint: Hint, + ) { + let id = hint_key_to_id(key); + if self.hints.contains_key(&id) { + panic!("Hint with id {} already exists", id); + } + self.hints.insert(id, Box::new(hint)); + } + pub fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + if let Some(hint) = self.hints.get_mut(&id) { + let mut outputs = vec![F::zero(); num_outputs]; + hint(args, &mut outputs).map(|_| outputs) + } else { + panic!("Hint with id {} not found", id); + } + } +} + +#[derive(Default)] +pub struct EmptyHintCaller; + +impl EmptyHintCaller { + pub fn new() -> Self { + Self + } +} + +pub trait HintCaller: 'static { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error>; +} + +impl HintCaller for HintRegistry { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + self.call(id, args, num_outputs) + } +} + +impl HintCaller for EmptyHintCaller { + fn call(&mut self, id: usize, _: &[F], _: usize) -> Result, Error> { + Err(Error::UserError(format!("hint with id {} not found", id))) + } +} diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index f2b9214d..05c58b1f 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -5,13 +5,13 @@ use crate::circuit::{ config::Config, ir::dest::{Circuit as IrCircuit, Instruction, RootCircuit as IrRootCircuit}, ir::expr::Expression, - layered::{Coef, Segment}, + layered::{Coef, InputType, Segment}, }; use crate::utils::pool::Pool; use super::layer_layout::{LayerLayout, LayerLayoutContext, LayerReq}; -pub struct CompileContext<'a, C: Config> { +pub struct CompileContext<'a, C: Config, I: InputType> { // the root circuit pub rc: &'a IrRootCircuit, @@ -26,8 +26,8 @@ pub struct CompileContext<'a, C: Config> { pub layer_req_to_layout: HashMap, // compiled layered circuits - pub compiled_circuits: Vec>, - pub conncected_wires: HashMap, + pub compiled_circuits: Vec>, + pub conncected_wires: HashMap, usize>, // layout id of each layer pub layout_ids: Vec, @@ -51,8 +51,12 @@ pub struct IrContext<'a, C: Config> { // it includes only variables mentioned in instructions, so internal variables in sub circuits are ignored here. pub min_layer: Vec, pub max_layer: Vec, + pub occured_layers: Vec>, pub output_layer: usize, + // for each layer i, the minimum layer j that there exists gate j->i + pub min_used_layer: Vec, + pub output_order: HashMap, // outputOrder[x] == y -> x is the y-th output pub sub_circuit_loc_map: HashMap, @@ -90,7 +94,7 @@ pub struct SubCircuitInsn<'a> { const EXTRA_PRE_ALLOC_SIZE: usize = 1000; -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn compile(&mut self) { // 1. do a toposort of the circuits self.dfs_topo_sort(0); @@ -116,10 +120,7 @@ impl<'a, C: Config> CompileContext<'a, C> { self.layout_ids = layout_ids; // 5. generate wires - let mut layers = Vec::with_capacity(self.circuits[&0].output_layer); - for i in 0..self.circuits[&0].output_layer { - layers.push(self.connect_wires(self.layout_ids[i], self.layout_ids[i + 1])); - } + let layers = self.connect_wires(&self.layout_ids.clone()); self.layers = layers; // 6. record the input order (used to generate witness) @@ -164,7 +165,9 @@ impl<'a, C: Config> CompileContext<'a, C> { num_sub_circuits: ns, min_layer: Vec::new(), max_layer: Vec::new(), + occured_layers: Vec::new(), output_layer: 0, + min_used_layer: Vec::new(), output_order: HashMap::new(), sub_circuit_loc_map: HashMap::new(), sub_circuit_insn_ids: Vec::new(), @@ -436,6 +439,53 @@ impl<'a, C: Config> CompileContext<'a, C> { } } + // compute occured layers + if I::CROSS_LAYER_RELAY { + ic.occured_layers = vec![Vec::new(); ic.max_layer.len()]; + let outputs_set: HashSet = circuit.outputs.iter().cloned().collect(); + for x in q.iter().cloned() { + let mut tmp = Vec::with_capacity(out_edges[x].len() + 1); + tmp.push(ic.min_layer[x]); + for y in out_edges[x].iter().cloned() { + tmp.push(ic.min_layer[y] - layer_advance[y]); + } + if outputs_set.contains(&x) { + tmp.push(ic.output_layer); + } + tmp.sort(); + let mut tmp2 = Vec::with_capacity(tmp.len()); + for &v in tmp.iter() { + if tmp2.is_empty() || *tmp2.last().unwrap() != v { + tmp2.push(v); + } + } + assert_eq!(tmp2[0], ic.min_layer[x]); + assert_eq!(*tmp2.last().unwrap(), ic.max_layer[x]); + ic.occured_layers[x] = tmp2; + } + } + + // compute minUsedLayer + ic.min_used_layer = Vec::with_capacity(ic.output_layer + 1); + ic.min_used_layer.push(0); + ic.min_used_layer.extend(0..ic.output_layer); + for (i, sc) in ic.sub_circuit_insn_refs.iter().enumerate() { + let sub_circuit = &self.circuits[&sc.sub_circuit_id]; + let input_layer = ic.sub_circuit_start_layer[i]; + for j in 0..=sub_circuit.output_layer { + ic.min_used_layer[j + input_layer] = ic.min_used_layer[j + input_layer] + .min(sub_circuit.min_used_layer[j] + input_layer); + } + } + if I::CROSS_LAYER_RELAY { + for x in q.iter().cloned() { + let t = &ic.occured_layers[x]; + for (u, v) in t.iter().zip(t.iter().skip(1)) { + ic.min_used_layer[*v] = ic.min_used_layer[*v].min(*u); + } + } + } + self.circuits.insert(circuit_id, ic); } } diff --git a/expander_compiler/src/layering/input.rs b/expander_compiler/src/layering/input.rs index e872325f..ae9532d9 100644 --- a/expander_compiler/src/layering/input.rs +++ b/expander_compiler/src/layering/input.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use crate::circuit::{config::Config, input_mapping::EMPTY}; +use crate::circuit::{config::Config, input_mapping::EMPTY, layered::InputType}; use super::{compile::CompileContext, layer_layout::LayerLayoutInner}; -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn record_input_order(&self) -> Vec { let layout_id = self.layout_ids[0]; let l = self.layer_layout_pool.get(layout_id); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index 930bc888..7cc6b43d 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, mem}; use crate::{ - circuit::{config::Config, input_mapping::EMPTY}, + circuit::{config::Config, input_mapping::EMPTY, layered::InputType}, utils::{misc::next_power_of_two, pool::Pool}, }; @@ -85,7 +85,7 @@ pub struct LayerReq { pub layer: usize, // which layer to solve? } -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn prepare_layer_layout_context(&mut self, circuit_id: usize) { let mut ic = self.circuits.remove(&circuit_id).unwrap(); @@ -100,8 +100,14 @@ impl<'a, C: Config> CompileContext<'a, C> { ic.lcs[ic.output_layer].vars.add(v); } for i in 1..ic.num_var { - for j in ic.min_layer[i]..=ic.max_layer[i] { - ic.lcs[j].vars.add(&i); + if I::CROSS_LAYER_RELAY { + for j in ic.occured_layers[i].iter().cloned() { + ic.lcs[j].vars.add(&i); + } + } else { + for j in ic.min_layer[i]..=ic.max_layer[i] { + ic.lcs[j].vars.add(&i); + } } } diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index c9f4cf7e..9ac7bfa5 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -1,7 +1,12 @@ use std::collections::HashMap; use crate::{ - circuit::{config::Config, input_mapping::InputMapping, ir, layered}, + circuit::{ + config::Config, + input_mapping::InputMapping, + ir, + layered::{self, InputType, InputUsize}, + }, utils::pool::Pool, }; @@ -14,7 +19,9 @@ mod wire; #[cfg(test)] mod tests; -pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit, InputMapping) { +pub fn compile( + rc: &ir::dest::RootCircuit, +) -> (layered::Circuit, InputMapping) { let mut ctx = compile::CompileContext { rc, circuits: HashMap::new(), @@ -29,7 +36,8 @@ pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit root_has_constraints: false, }; ctx.compile(); - let l0_size = ctx.compiled_circuits[ctx.layers[0]].num_inputs; + let t: &I::InputUsize = &ctx.compiled_circuits[ctx.layers[0]].num_inputs; + let l0_size = t.get(0); let output_zeroes = rc.expected_num_output_zeroes + ctx.root_has_constraints as usize; let output_all = rc.circuits[&0].outputs.len() + ctx.root_has_constraints as usize; ( diff --git a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index 6eef63f0..b1c53f88 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -1,17 +1,23 @@ use crate::circuit::{ config::{Config, M31Config as C}, input_mapping::InputMapping, - ir::{common::rand_gen::*, dest::RootCircuit as IrRootCircuit}, - layered, + ir::{ + common::rand_gen::*, + dest::{Circuit as IrCircuit, Instruction as IrInstruction, RootCircuit as IrRootCircuit}, + expr::{Expression, Term}, + }, + layered::{self, CrossLayerInputType, InputType, NormalInputType}, }; +use crate::field::M31 as CField; + use crate::field::FieldArith; use super::compile; -pub fn test_input( +pub fn test_input( rc: &IrRootCircuit, - lc: &layered::Circuit, + lc: &layered::Circuit, input_mapping: &InputMapping, input: &Vec, ) { @@ -22,10 +28,10 @@ pub fn test_input( assert_eq!(rc_output, lc_output); } -pub fn compile_and_random_test( +pub fn compile_and_random_test( rc: &IrRootCircuit, n_tests: usize, -) -> (layered::Circuit, InputMapping) { +) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); let (lc, input_mapping) = compile(rc); assert_eq!(lc.validate(), Ok(())); @@ -56,7 +62,8 @@ fn random_circuits_1() { config.seed = i; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -76,7 +83,8 @@ fn random_circuits_2() { config.seed = i + 10000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -96,7 +104,8 @@ fn random_circuits_3() { config.seed = i + 20000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -119,6 +128,40 @@ fn random_circuits_4() { config.seed = i + 30000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); + } +} + +#[test] +fn cross_layer_circuit() { + let mut root = IrRootCircuit::::default(); + const N: usize = 1000; + root.circuits.insert( + 0, + IrCircuit:: { + instructions: vec![], + constraints: vec![N * 2 - 1], + outputs: vec![], + num_inputs: N, + }, + ); + for i in 0..N - 1 { + root.circuits + .get_mut(&0) + .unwrap() + .instructions + .push(IrInstruction::InternalVariable { + expr: Expression::from_terms(vec![ + Term::new_linear(CField::one(), N + i), + Term::new_linear(CField::one(), N - i - 1), + ]), + }); + } + assert_eq!(root.validate(), Ok(())); + let (lc, _) = compile_and_random_test::<_, CrossLayerInputType>(&root, 5); + assert!((lc.layer_ids.len() as isize - N as isize).abs() <= 10); + for i in lc.layer_ids.iter() { + assert!(lc.segments[*i].gate_adds.len() <= 10); } } diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index c7d0a71f..6e671843 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -5,7 +5,10 @@ use crate::{ config::Config, input_mapping::EMPTY, ir::expr::VarSpec, - layered::{Allocation, Coef, GateAdd, GateConst, GateCustom, GateMul, Segment}, + layered::{ + Allocation, Coef, GateAdd, GateConst, GateCustom, GateMul, Input, InputType, + InputUsize, Segment, + }, }, field::FieldArith, utils::pool::Pool, @@ -96,7 +99,7 @@ impl LayoutQuery { } } -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { fn layout_query(&self, l: &LayerLayout, s: &[usize]) -> LayoutQuery { let mut var_pos = HashMap::new(); match &l.inner { @@ -122,42 +125,41 @@ impl<'a, C: Config> CompileContext<'a, C> { LayoutQuery { var_pos } } - pub fn connect_wires(&mut self, a_: usize, b_: usize) -> usize { - let map_id = (a_ as u128) << 64 | b_ as u128; - if let Some(x) = self.conncected_wires.get(&map_id) { - return *x; - } - let a = self.layer_layout_pool.get(a_).clone(); - let b = self.layer_layout_pool.get(b_).clone(); - if (a.layer + 1 != b.layer) || a.circuit_id != b.circuit_id { - panic!("unexpected situation"); + pub fn connect_wires(&mut self, layout_ids: &[usize]) -> Vec { + let layouts = layout_ids + .iter() + .map(|x| self.layer_layout_pool.get(*x).clone()) + .collect::>(); + for (a, b) in layouts.iter().zip(layouts.iter().skip(1)) { + if a.layer + 1 != b.layer || a.circuit_id != b.circuit_id { + panic!("unexpected situation"); + } } - let circuit_id = a.circuit_id; - let ic = self.circuits.remove(&circuit_id).unwrap(); - let cur_layer = a.layer; - let next_layer = b.layer; - let (cur_lc, next_lc) = (&ic.lcs[cur_layer], &ic.lcs[next_layer]); - let aq = self.layout_query(&a, cur_lc.vars.vec()); - let bq = self.layout_query(&b, next_lc.vars.vec()); - - // check if all variables exist in the layout - for x in cur_lc.vars.vec().iter() { - if !aq.var_pos.contains_key(x) { + for (i, a) in layouts.iter().enumerate() { + if i != a.layer { panic!("unexpected situation"); } } - if cur_layer + 1 != ic.output_layer { - for x in next_lc.vars.vec().iter() { - if !bq.var_pos.contains_key(x) { + let circuit_id = layouts[0].circuit_id; + let ic = self.circuits.remove(&circuit_id).unwrap(); + if layouts.len() != ic.output_layer + 1 { + panic!("unexpected situation"); + } + let lqs = layouts + .iter() + .map(|x| self.layout_query(x, ic.lcs[x.layer].vars.vec())) + .collect::>(); + + for (lc, lq) in ic.lcs.iter().zip(lqs.iter()).take(ic.output_layer) { + for x in lc.vars.vec() { + if !lq.var_pos.contains_key(x) { panic!("unexpected situation"); } } } - let mut sub_insns: Pool = Pool::new(); - let mut sub_cur_layout: Vec> = Vec::new(); - let mut sub_next_layout: Vec> = Vec::new(); - let mut sub_cur_layout_all: HashMap = HashMap::new(); + let mut sub_layouts_of_layer: Vec> = + vec![HashMap::new(); ic.output_layer + 1]; // find all sub circuits for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { @@ -167,225 +169,302 @@ impl<'a, C: Config> CompileContext<'a, C> { let dep = sub_c.output_layer; let input_layer = ic.sub_circuit_start_layer[i]; let output_layer = input_layer + dep; - let mut cur_layout = None; - let mut next_layout = None; - let outf = |x: usize| -> usize { sub_c.circuit.outputs[x] }; - if input_layer <= cur_layer && output_layer >= next_layer { - // normal - if input_layer == cur_layer { - // for the input layer, we need to manually query the layout. (other layers are already subLayouts) - let vs = insn.inputs.clone(); - cur_layout = Some(aq.query( - &mut self.layer_layout_pool, - &self.circuits, - &vs, - |x| x + 1, - sub_id, - 0, - )); - } - if output_layer == next_layer { - // also for the output layer - next_layout = Some(bq.query( - &mut self.layer_layout_pool, - &self.circuits, - &insn.outputs, - outf, - sub_id, - dep, - )); - } - } else if cur_layer == output_layer { - cur_layout = Some(aq.query( + + sub_layouts_of_layer[input_layer].insert( + *insn_id, + lqs[input_layer].query( + &mut self.layer_layout_pool, + &self.circuits, + insn.inputs, + |x| x + 1, + sub_id, + 0, + ), + ); + sub_layouts_of_layer[output_layer].insert( + *insn_id, + lqs[output_layer].query( &mut self.layer_layout_pool, &self.circuits, &insn.outputs, - outf, + |x| sub_c.circuit.outputs[x], sub_id, dep, - )); - sub_cur_layout_all.insert(*insn_id, cur_layout.unwrap()); - continue; - } else { - continue; - } - sub_insns.add(insn_id); - sub_cur_layout.push(cur_layout); - sub_next_layout.push(next_layout); + ), + ); } - // fill already known subLayouts - let a = self.layer_layout_pool.get(a_); - let b = self.layer_layout_pool.get(b_); // fill already known sub_layouts - if let LayerLayoutInner::Sparse { sub_layout, .. } = &a.inner { - for x in sub_layout.iter() { - sub_cur_layout[sub_insns.get_idx(&x.insn_id)] = Some(x.clone()); + for (i, a) in layouts.iter().enumerate() { + if let LayerLayoutInner::Sparse { sub_layout, .. } = &a.inner { + for x in sub_layout.iter() { + sub_layouts_of_layer[i].insert(x.insn_id, x.clone()); + } } } - if let LayerLayoutInner::Sparse { sub_layout, .. } = &b.inner { - for x in sub_layout.iter() { - sub_next_layout[sub_insns.get_idx(&x.insn_id)] = Some(x.clone()); - } + + let mut ress: Vec> = Vec::new(); + for (i, b) in layouts.iter().enumerate().skip(1) { + let num_inputs_vec = (ic.min_used_layer[i]..i) + .rev() + .map(|j| layouts[j].size) + .collect(); + ress.push(Segment { + num_inputs: I::InputUsize::from_vec(num_inputs_vec), + num_outputs: b.size, + ..Default::default() + }); } - let mut res: Segment = Segment { - num_inputs: a.size, - num_outputs: b.size, - ..Default::default() - }; + let mut cached_ress = Vec::with_capacity(ic.output_layer); + for i in 1..=ic.output_layer { + let key = layout_ids[ic.min_used_layer[i]..=i].to_vec(); + cached_ress.push(self.conncected_wires.get(&key).cloned()); + } + let all_cached = cached_ress.iter().all(|x| x.is_some()); + if all_cached { + return cached_ress.into_iter().map(|x| x.unwrap()).collect(); + } // connect sub circuits - for i in 0..sub_insns.len() { - let sub_cur_layout = sub_cur_layout[i].as_ref().unwrap(); - let sub_next_layout = sub_next_layout[i].as_ref().unwrap(); - sub_cur_layout_all.insert(*sub_insns.get(i), sub_cur_layout.clone()); - let scid = self.connect_wires(sub_cur_layout.id, sub_next_layout.id); - let al = Allocation { - input_offset: sub_cur_layout.offset, - output_offset: sub_next_layout.offset, - }; - let mut found = false; - for j in 0..=res.child_segs.len() { - if j == res.child_segs.len() { - res.child_segs.push((scid, vec![al])); - found = true; - break; + for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { + let insn = &ic.sub_circuit_insn_refs[i]; + let sub_id = insn.sub_circuit_id; + let sub_c = &self.circuits[&sub_id]; + let dep = sub_c.output_layer; + let input_layer = ic.sub_circuit_start_layer[i]; + let output_layer = input_layer + dep; + + let cur_sub_layout_ids = (input_layer..=output_layer) + .map(|x| sub_layouts_of_layer[x][insn_id].id) + .collect::>(); + let segment_ids = self.connect_wires(&cur_sub_layout_ids); + let sub_c = &self.circuits[&sub_id]; + + for (i, segment_id) in segment_ids.iter().enumerate() { + let alloc_min_layer = sub_c.min_used_layer[i + 1] + input_layer; + let input_offset_vec = (alloc_min_layer..=input_layer + i) + .rev() + .map(|x| sub_layouts_of_layer[x][insn_id].offset) + .collect::>(); + let al = Allocation { + input_offset: I::InputUsize::from_vec(input_offset_vec), + output_offset: sub_layouts_of_layer[input_layer + i + 1][insn_id].offset, + }; + let mut found = false; + let child_segs = &mut ress[input_layer + i].child_segs; + for j in 0..=child_segs.len() { + if j == child_segs.len() { + child_segs.push((*segment_id, vec![al])); + found = true; + break; + } + if child_segs[j].0 == *segment_id { + child_segs[j].1.push(al); + found = true; + break; + } } - if res.child_segs[j].0 == scid { - res.child_segs[j].1.push(al); - found = true; - break; + if !found { + panic!("unexpected situation"); } } - if !found { - panic!("unexpected situation"); - } } // connect self variables - for x in next_lc.vars.vec().iter() { - // only consider real variables - if *x >= ic.num_var { - continue; + for x in 0..ic.num_var { + // connect first occurance + if ic.min_layer[x] != 0 { + let next_layer = ic.min_layer[x]; + let cur_layer = next_layer - 1; + if cached_ress[cur_layer].is_none() { + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(cur_layer + 1, ic.output_layer); + continue; + }; + if let Some(value) = ic.constant_like_variables.get(&x) { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: value.clone(), + }); + } else if ic.internal_variable_expr.contains_key(&x) { + for term in ic.internal_variable_expr[&x].iter() { + match &term.vars { + VarSpec::Const => { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Linear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; + res.gate_muls.push(GateMul { + inputs: [ + I::Input::new(0, inputs[0]), + I::Input::new(0, inputs[1]), + ], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Custom { gate_type, inputs } => { + res.gate_customs.push(GateCustom { + gate_type: *gate_type, + inputs: inputs + .iter() + .map(|x| I::Input::new(0, aq.var_pos[x])) + .collect(), + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::RandomLinear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Random, + }); + } + } + } + } + } } - let pos = if let Some(p) = bq.var_pos.get(x) { - *p + // connect relays (this may generate cross layer connections) + if I::CROSS_LAYER_RELAY { + for (cur_layer, next_layer) in ic.occured_layers[x] + .iter() + .zip(ic.occured_layers[x].iter().skip(1)) + { + if cached_ress[next_layer - 1].is_none() { + let res = &mut ress[next_layer - 1]; + let aq = &lqs[*cur_layer]; + let bq = &lqs[*next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(*next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); + } + } } else { - assert_eq!(cur_layer + 1, ic.output_layer); - //assert!(!ic.output_order.contains_key(x)); - continue; - }; - // if it's not the first layer, just relay it - if ic.min_layer[*x] != next_layer { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[x]], - output: pos, - coef: Coef::Constant(C::CircuitField::one()), - }); - continue; - } - if let Some(value) = ic.constant_like_variables.get(x) { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: value.clone(), - }); - } else if ic.internal_variable_expr.contains_key(x) { - for term in ic.internal_variable_expr[x].iter() { - match &term.vars { - VarSpec::Const => { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Linear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[vid]], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Quad(vid0, vid1) => { - res.gate_muls.push(GateMul { - inputs: [aq.var_pos[vid0], aq.var_pos[vid1]], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Custom { gate_type, inputs } => { - res.gate_customs.push(GateCustom { - gate_type: *gate_type, - inputs: inputs.iter().map(|x| aq.var_pos[x]).collect(), - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::RandomLinear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[vid]], - output: pos, - coef: Coef::Random, - }); - } + for cur_layer in ic.min_layer[x]..ic.max_layer[x] { + let next_layer = cur_layer + 1; + if cached_ress[cur_layer].is_none() { + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); } } } } // also combined output variables - let cc = ic.combined_constraints[next_layer].as_ref(); - if let Some(cc) = cc { - let pos = bq.var_pos[&cc.id]; - for v in cc.variables.iter() { - let coef = if *v >= ic.num_var { - Coef::Constant(C::CircuitField::one()) - } else { - Coef::Random - }; - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[v]], - output: pos, - coef, - }); - } - for i in cc.sub_circuit_ids.iter() { - let insn_id = ic.sub_circuit_insn_ids[*i]; - let insn = &ic.sub_circuit_insn_refs[*i]; - let input_layer = ic.sub_circuit_start_layer[*i]; - let vid = self.circuits[&insn.sub_circuit_id].combined_constraints - [cur_layer - input_layer] - .as_ref() - .unwrap() - .id; - let vpid = self.circuits[&insn.sub_circuit_id].lcs[cur_layer - input_layer] - .vars - .get_idx(&vid); - let layout = self.layer_layout_pool.get(sub_cur_layout_all[&insn_id].id); - let spos = match &layout.inner { - LayerLayoutInner::Sparse { placement, .. } => placement - .iter() - .find_map(|(i, v)| if *v == vpid { Some(*i) } else { None }) - .unwrap(), - LayerLayoutInner::Dense { placement } => { - placement.iter().position(|x| *x == vpid).unwrap() - } - }; - res.gate_adds.push(GateAdd { - inputs: [sub_cur_layout_all[&insn_id].offset + spos], - output: pos, - coef: Coef::Constant(C::CircuitField::one()), - }); + for (cur_layer, ((cc, bq), aq)) in ic + .combined_constraints + .iter() + .zip(lqs.iter()) + .skip(1) + .zip(lqs.iter()) + .enumerate() + { + let res = &mut ress[cur_layer]; + if let Some(cc) = cc { + let pos = bq.var_pos[&cc.id]; + for v in cc.variables.iter() { + let coef = if *v >= ic.num_var { + Coef::Constant(C::CircuitField::one()) + } else { + Coef::Random + }; + res.gate_adds.push(GateAdd { + inputs: [Input::new(0, aq.var_pos[v])], + output: pos, + coef, + }); + } + for i in cc.sub_circuit_ids.iter() { + let insn_id = ic.sub_circuit_insn_ids[*i]; + let insn = &ic.sub_circuit_insn_refs[*i]; + let input_layer = ic.sub_circuit_start_layer[*i]; + let vid = self.circuits[&insn.sub_circuit_id].combined_constraints + [cur_layer - input_layer] + .as_ref() + .unwrap() + .id; + let vpid = self.circuits[&insn.sub_circuit_id].lcs[cur_layer - input_layer] + .vars + .get_idx(&vid); + let layout = self + .layer_layout_pool + .get(sub_layouts_of_layer[cur_layer][&insn_id].id); + let spos = match &layout.inner { + LayerLayoutInner::Sparse { placement, .. } => placement + .iter() + .find_map(|(i, v)| if *v == vpid { Some(*i) } else { None }) + .unwrap(), + LayerLayoutInner::Dense { placement } => { + placement.iter().position(|x| *x == vpid).unwrap() + } + }; + res.gate_adds.push(GateAdd { + inputs: [Input::new( + 0, + sub_layouts_of_layer[cur_layer][&insn_id].offset + spos, + )], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); + } } } - let res_id = self.compiled_circuits.len(); - self.compiled_circuits.push(res); - self.conncected_wires.insert(map_id, res_id); + let mut ress_ids = Vec::new(); + + for (res, cache) in ress.iter().zip(cached_ress.iter()) { + if let Some(cache) = cache { + ress_ids.push(*cache); + continue; + } + let res_id = self.compiled_circuits.len(); + self.compiled_circuits.push(res.clone()); + ress_ids.push(res_id); + } self.circuits.insert(circuit_id, ic); - res_id + ress_ids } } diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 8ea12c7e..198744d8 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,9 +1,6 @@ use arith::Field; use expander_compiler::frontend::*; -use expander_config::{ - BN254ConfigKeccak, BN254ConfigSha2, GF2ExtConfigKeccak, GF2ExtConfigSha2, M31ExtConfigKeccak, - M31ExtConfigSha2, -}; +use rand::SeedableRng; declare_circuit!(Circuit { s: [Variable; 100], @@ -20,16 +17,14 @@ impl Define for Circuit { } } -fn example() -where - GKRC: expander_config::GKRConfig, -{ - let n_witnesses = ::pack_size(); +fn example() { + let n_witnesses = ::PACK_SIZE; println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); let mut s = [C::CircuitField::zero(); 100]; + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for i in 0..s.len() { - s[i] = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + s[i] = C::CircuitField::random_unsafe(&mut rng); } let assignment = Circuit:: { s, @@ -42,53 +37,47 @@ where .unwrap(); let output = compile_result.layered_circuit.run(&witness); for x in output.iter() { - assert_eq!(*x, true); + assert!(*x); } let mut expander_circuit = compile_result .layered_circuit - .export_to_expander::() + .export_to_expander::() .flatten(); - let config = expander_config::Config::::new( + let config = expander_config::Config::::new( expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), + mpi_config::MPIConfig::new(), ); - let (simd_input, simd_public_input) = witness.to_simd::(); + let (simd_input, simd_public_input) = witness.to_simd::(); println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); // prove expander_circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&expander_circuit); - let (claimed_v, proof) = prover.prove(&mut expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); // verify - let verifier = gkr::Verifier::new(&config); - assert!(verifier.verify( + assert!(gkr::executor::verify( &mut expander_circuit, - &simd_public_input, - &claimed_v, - &proof + &config, + &proof, + &claimed_v )); } #[test] fn example_gf2() { - example::(); - example::(); + example::(); } #[test] fn example_m31() { - example::(); - example::(); + example::(); } #[test] fn example_bn254() { - example::(); - example::(); + example::(); } diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 2d3b4422..445c3732 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,6 +1,7 @@ -use expander_compiler::frontend::*; +use expander_compiler::{circuit::layered::InputType, frontend::*}; +use extra::*; use internal::Serde; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -34,8 +35,8 @@ fn rc() -> Vec { ] } -fn xor_in( - api: &mut API, +fn xor_in>( + api: &mut B, mut s: Vec>, buf: Vec>, ) -> Vec> { @@ -49,7 +50,10 @@ fn xor_in( s } -fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { +fn keccak_f>( + api: &mut B, + mut a: Vec>, +) -> Vec> { let mut b = vec![vec![api.constant(0); 64]; 25]; let mut c = vec![vec![api.constant(0); 64]; 5]; let mut d = vec![vec![api.constant(0); 64]; 5]; @@ -133,7 +137,7 @@ fn keccak_f(api: &mut API, mut a: Vec>) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn xor>(api: &mut B, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -142,7 +146,7 @@ fn xor(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn and>(api: &mut B, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -151,7 +155,7 @@ fn and(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec) -> Vec { +fn not>(api: &mut B, a: Vec) -> Vec { let mut bits_res = vec![api.constant(0); a.len()]; for i in 0..a.len() { bits_res[i] = api.sub(1, a[i].clone()); @@ -189,7 +193,7 @@ declare_circuit!(Keccak256Circuit { out: [[PublicVariable; 256]; N_HASHES], }); -fn compute_keccak(api: &mut API, p: &Vec) -> Vec { +fn compute_keccak>(api: &mut B, p: &Vec) -> Vec { let mut ss = vec![vec![api.constant(0); 64]; 25]; let mut new_p = p.clone(); let mut append_data = vec![0; 136 - 64]; @@ -211,12 +215,13 @@ fn compute_keccak(api: &mut API, p: &Vec) -> Vec for Keccak256Circuit { - fn define(&self, api: &mut API) { +impl GenericDefine for Keccak256Circuit { + fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits - // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); - let out = compute_keccak(api, &self.p[i].to_vec()); + // Or use the function directly + let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + //let out = compute_keccak(api, &self.p[i].to_vec()); for j in 0..256 { api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); } @@ -224,19 +229,17 @@ impl Define for Keccak256Circuit { } } -#[test] -fn keccak_gf2_main() { - let compile_result = compile(&Keccak256Circuit::default()).unwrap(); - let CompileResult { - witness_solver, - layered_circuit, - } = compile_result; - +fn keccak_gf2_test( + witness_solver: WitnessSolver, + layered_circuit: expander_compiler::circuit::layered::Circuit, + filename: &str, +) { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -288,17 +291,105 @@ fn keccak_gf2_main() { .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit_gf2.txt").unwrap(); + let file = std::fs::File::create(format!("circuit_{}.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_gf2.txt").unwrap(); + let file = std::fs::File::create(format!("witness_{}.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_gf2_solver.txt").unwrap(); + let file = std::fs::File::create(format!("witness_{}_solver.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); println!("dumped to files"); } + +#[test] +fn keccak_gf2_main() { + let compile_result = + compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + keccak_gf2_test(witness_solver, layered_circuit, "gf2"); +} + +#[test] +fn keccak_gf2_main_cross_layer() { + let compile_result = + compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) + .unwrap(); + let CompileResultCrossLayer { + witness_solver, + layered_circuit, + } = compile_result; + keccak_gf2_test(witness_solver, layered_circuit, "gf2_cross_layer"); +} + +#[test] +fn keccak_gf2_debug() { + let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = rng.gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + + debug_eval( + &Keccak256Circuit::default(), + &assignment, + EmptyHintCaller::new(), + ); +} + +#[test] +#[should_panic] +fn keccak_gf2_debug_error() { + let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = rng.gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = (((output[i] >> j) as u32 & 1) ^ 1).into(); + } + } + } + + debug_eval( + &Keccak256Circuit::default(), + &assignment, + EmptyHintCaller::new(), + ); +} diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index cab14168..973c6f68 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -1,5 +1,5 @@ use expander_compiler::frontend::*; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -232,10 +232,11 @@ fn keccak_gf2_full() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -281,31 +282,30 @@ fn keccak_gf2_full() { assert_eq!(res, expected_res); println!("test 3 passed"); + // alternatively, you can specify the particular config like gkr_field_config::GF2ExtConfig let mut expander_circuit = layered_circuit - .export_to_expander::() + .export_to_expander::<::DefaultGKRFieldConfig>() .flatten(); - let config = expander_config::Config::::new( + let config = expander_config::Config::<::DefaultGKRConfig>::new( expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), + mpi_config::MPIConfig::new(), ); - let (simd_input, simd_public_input) = witness.to_simd::(); + let (simd_input, simd_public_input) = + witness.to_simd::<::DefaultSimdField>(); println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); // prove expander_circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&expander_circuit); - let (claimed_v, proof) = prover.prove(&mut expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); // verify - let verifier = gkr::Verifier::new(&config); - assert!(verifier.verify( + assert!(gkr::executor::verify( &mut expander_circuit, - &simd_public_input, - &claimed_v, - &proof + &config, + &proof, + &claimed_v )); } diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs new file mode 100644 index 00000000..6e4bc6d4 --- /dev/null +++ b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs @@ -0,0 +1,307 @@ +use expander_compiler::frontend::*; +use expander_transcript::{BytesHashTranscript, SHA256hasher, Transcript}; +use rand::{Rng, SeedableRng}; +use tiny_keccak::Hasher; + +const N_HASHES: usize = 1; + +fn rc() -> Vec { + vec![ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + ] +} + +fn xor_in>( + api: &mut B, + mut s: Vec>, + buf: Vec>, +) -> Vec> { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < buf.len() { + s[5 * x + y] = xor(api, s[5 * x + y].clone(), buf[x + 5 * y].clone()) + } + } + } + s +} + +fn keccak_f>( + api: &mut B, + mut a: Vec>, +) -> Vec> { + let mut b = vec![vec![api.constant(0); 64]; 25]; + let mut c = vec![vec![api.constant(0); 64]; 5]; + let mut d = vec![vec![api.constant(0); 64]; 5]; + let mut da = vec![vec![api.constant(0); 64]; 5]; + let rc = rc(); + + for i in 0..24 { + for j in 0..5 { + let t1 = xor(api, a[j * 5 + 1].clone(), a[j * 5 + 2].clone()); + let t2 = xor(api, a[j * 5 + 3].clone(), a[j * 5 + 4].clone()); + c[j] = xor(api, t1, t2); + } + + for j in 0..5 { + d[j] = xor( + api, + c[(j + 4) % 5].clone(), + rotate_left::(&c[(j + 1) % 5], 1), + ); + da[j] = xor( + api, + a[((j + 4) % 5) * 5].clone(), + rotate_left::(&a[((j + 1) % 5) * 5], 1), + ); + } + + for j in 0..25 { + let tmp = xor(api, da[j / 5].clone(), a[j].clone()); + a[j] = xor(api, tmp, d[j / 5].clone()); + } + + /*Rho and pi steps*/ + b[0] = a[0].clone(); + + b[8] = rotate_left::(&a[1], 36); + b[11] = rotate_left::(&a[2], 3); + b[19] = rotate_left::(&a[3], 41); + b[22] = rotate_left::(&a[4], 18); + + b[2] = rotate_left::(&a[5], 1); + b[5] = rotate_left::(&a[6], 44); + b[13] = rotate_left::(&a[7], 10); + b[16] = rotate_left::(&a[8], 45); + b[24] = rotate_left::(&a[9], 2); + + b[4] = rotate_left::(&a[10], 62); + b[7] = rotate_left::(&a[11], 6); + b[10] = rotate_left::(&a[12], 43); + b[18] = rotate_left::(&a[13], 15); + b[21] = rotate_left::(&a[14], 61); + + b[1] = rotate_left::(&a[15], 28); + b[9] = rotate_left::(&a[16], 55); + b[12] = rotate_left::(&a[17], 25); + b[15] = rotate_left::(&a[18], 21); + b[23] = rotate_left::(&a[19], 56); + + b[3] = rotate_left::(&a[20], 27); + b[6] = rotate_left::(&a[21], 20); + b[14] = rotate_left::(&a[22], 39); + b[17] = rotate_left::(&a[23], 8); + b[20] = rotate_left::(&a[24], 14); + + /*Xi state*/ + + for j in 0..25 { + let t = not(api, b[(j + 5) % 25].clone()); + let t = and(api, t, b[(j + 10) % 25].clone()); + a[j] = xor(api, b[j].clone(), t); + } + + /*Last step*/ + + for j in 0..64 { + if rc[i] >> j & 1 == 1 { + a[0][j] = api.sub(1, a[0][j]); + } + } + } + + a +} + +fn xor>(api: &mut B, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.add(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn and>(api: &mut B, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.mul(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn not>(api: &mut B, a: Vec) -> Vec { + let mut bits_res = vec![api.constant(0); a.len()]; + for i in 0..a.len() { + bits_res[i] = api.sub(1, a[i].clone()); + } + bits_res +} + +fn rotate_left(bits: &Vec, k: usize) -> Vec { + let n = bits.len(); + let s = k & (n - 1); + let mut new_bits = bits[(n - s) as usize..].to_vec(); + new_bits.append(&mut bits[0..(n - s) as usize].to_vec()); + new_bits +} + +fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> Vec { + let mut out = vec![]; + let w = 8; + let mut b = 0; + while b < output_len { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < rate / w && b < output_len { + out.append(&mut s[5 * x + y].clone()); + b += 8; + } + } + } + } + out +} + +declare_circuit!(Keccak256Circuit { + p: [[Variable; 64 * 8]; N_HASHES], + out: [[Variable; 256]; N_HASHES], +}); + +fn compute_keccak>(api: &mut B, p: &Vec) -> Vec { + let mut ss = vec![vec![api.constant(0); 64]; 25]; + let mut new_p = p.clone(); + let mut append_data = vec![0; 136 - 64]; + append_data[0] = 1; + append_data[135 - 64] = 0x80; + for i in 0..136 - 64 { + for j in 0..8 { + new_p.push(api.constant(((append_data[i] >> j) & 1) as u32)); + } + } + let mut p = vec![vec![api.constant(0); 64]; 17]; + for i in 0..17 { + for j in 0..64 { + p[i][j] = new_p[i * 64 + j].clone(); + } + } + ss = xor_in(api, ss, p); + ss = keccak_f(api, ss); + copy_out_unaligned(ss, 136, 32) +} + +impl GenericDefine for Keccak256Circuit { + fn define>(&self, api: &mut Builder) { + for i in 0..N_HASHES { + // You can use api.memorized_simple_call for sub-circuits + // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + let out = compute_keccak(api, &self.p[i].to_vec()); + for j in 0..256 { + api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); + } + } + } +} + +#[test] +fn keccak_gf2_full_crosslayer() { + let compile_result = + compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) + .unwrap(); + let CompileResultCrossLayer { + witness_solver, + layered_circuit, + } = compile_result; + + let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = rng.gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + + let mut assignments = Vec::new(); + for _ in 0..8 { + assignments.push(assignment.clone()); + } + let witness = witness_solver.solve_witnesses(&assignments).unwrap(); + let res = layered_circuit.run(&witness); + let expected_res = vec![true; 8]; + assert_eq!(res, expected_res); + println!("basic test passed"); + + let expander_circuit = layered_circuit + .export_to_expander::() + .flatten(); + + let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); + assert_eq!(simd_public_input.len(), 0); // public input is not supported in current virgo++ + + let mut transcript = BytesHashTranscript::< + ::ChallengeField, + SHA256hasher, + >::new(); + + let connections = crosslayer_prototype::CrossLayerConnections::parse_circuit(&expander_circuit); + + let start_time = std::time::Instant::now(); + let evals = expander_circuit.evaluate(&simd_input); + let mut sp = + crosslayer_prototype::CrossLayerProverScratchPad::::new( + expander_circuit.layers.len(), + expander_circuit.max_num_input_var(), + expander_circuit.max_num_output_var(), + 1, + ); + let (_output_claim, _input_challenge, _input_claim) = crosslayer_prototype::prove_gkr( + &expander_circuit, + &evals, + &connections, + &mut transcript, + &mut sp, + ); + let stop_time = std::time::Instant::now(); + let duration = stop_time.duration_since(start_time); + println!("Time elapsed {} ms", duration.as_millis()); +} diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs index af8acd3f..207f1c87 100644 --- a/expander_compiler/tests/keccak_gf2_vec.rs +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -1,5 +1,5 @@ use expander_compiler::frontend::*; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 4; @@ -234,12 +234,13 @@ fn keccak_gf2_vec() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); assignment.p = vec![vec![GF2::from(0); 64 * 8]; N_HASHES]; assignment.out = vec![vec![GF2::from(0); 32 * 8]; N_HASHES]; for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index a541074c..686f862d 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -2,7 +2,7 @@ use ethnum::U256; use expander_compiler::field::{FieldArith, FieldModulus}; use expander_compiler::frontend::*; use internal::Serde; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 2; @@ -292,10 +292,11 @@ fn keccak_big_field(field_name: &str) { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -315,7 +316,7 @@ fn keccak_big_field(field_name: &str) { let out_compressed = compress_bits(out_bits); assert_eq!(out_compressed.len(), CHECK_PARTITIONS); for (i, x) in out_compressed.iter().enumerate() { - assert!(U256::from(*x as u64) < C::CircuitField::modulus()); + assert!(U256::from(*x as u64) < C::CircuitField::MODULUS); assignment.out[k][i] = C::CircuitField::from(*x as u32); } } diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs new file mode 100644 index 00000000..dd4076b7 --- /dev/null +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -0,0 +1,74 @@ +use expander_compiler::{circuit::layered::InputUsize, frontend::*}; + +declare_circuit!(Circuit { + x: [Variable; 16], + y: [Variable; 512], + sum: Variable, +}); + +impl GenericDefine for Circuit { + fn define>(&self, builder: &mut Builder) { + let mut sum = builder.constant(0); + for i in 0..16 { + for j in 0..512 { + let t = builder.mul(self.x[i], self.y[j]); + sum = builder.add(sum, t); + } + } + builder.assert_is_equal(self.sum, sum); + } +} + +fn mul_fanout_limit(limit: usize) { + let compile_result = compile_generic( + &Circuit::default(), + CompileOptions::default().with_mul_fanout_limit(limit), + ) + .unwrap(); + let circuit = compile_result.layered_circuit; + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs.get(0)]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0].offset] += 1; + ref_num[m.inputs[1].offset] += 1; + } + for x in ref_num.iter() { + assert!(*x <= limit); + } + } +} + +#[test] +fn mul_fanout_limit_2() { + mul_fanout_limit(2); +} + +#[test] +fn mul_fanout_limit_3() { + mul_fanout_limit(3); +} + +#[test] +fn mul_fanout_limit_4() { + mul_fanout_limit(4); +} + +#[test] +fn mul_fanout_limit_16() { + mul_fanout_limit(16); +} + +#[test] +fn mul_fanout_limit_64() { + mul_fanout_limit(64); +} + +#[test] +fn mul_fanout_limit_256() { + mul_fanout_limit(256); +} + +#[test] +fn mul_fanout_limit_1024() { + mul_fanout_limit(1024); +} diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs new file mode 100644 index 00000000..258a5e00 --- /dev/null +++ b/expander_compiler/tests/to_binary_hint.rs @@ -0,0 +1,89 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + input: PublicVariable, +}); + +fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { + api.new_hint("myhint.tobinary", &[x], n_bits) +} + +fn from_binary(api: &mut API, bits: Vec) -> Variable { + let mut res = api.constant(0); + for i in 0..bits.len() { + let coef = 1 << i; + let cur = api.mul(coef, bits[i]); + res = api.add(res, cur); + } + res +} + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let bits = to_binary(builder, self.input, 8); + let x = from_binary(builder, bits); + builder.assert_is_equal(x, self.input); + } +} + +fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +#[test] +fn test_300() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } +} + +#[test] +fn test_300_closure() { + let mut hint_registry = HintRegistry::::new(); + let call_count = Rc::new(RefCell::new(0)); + let call_count_clone = call_count.clone(); + hint_registry.register( + "myhint.tobinary", + move |x: &[M31], y: &mut [M31]| -> Result<(), Error> { + *call_count_clone.borrow_mut() += 1; + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) + }, + ); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } + assert_eq!(*call_count.borrow(), 300); +}