diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..d133ffe95 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directories: ["/", "/examples"] + schedule: + interval: "weekly" diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 8f77818e1..2918a3ce1 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -10,7 +10,7 @@ on: jobs: skip_check: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 outputs: should_skip: ${{ steps.skip_check.outputs.should_skip }} steps: @@ -30,7 +30,7 @@ jobs: name: Integration testing timeout-minutes: 30 - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: matrix: diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 36991eb21..299714e9b 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -10,7 +10,7 @@ on: jobs: skip_check: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 outputs: should_skip: ${{ steps.skip_check.outputs.should_skip }} steps: @@ -30,7 +30,7 @@ jobs: name: Various lints timeout-minutes: 30 - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: matrix: @@ -77,3 +77,8 @@ jobs: cargo make clippy # Same for clippy: cargo clippy --workspace --all-targets --exclude ceno_rt --release + + - name: Install taplo + run: taplo --version || cargo install taplo-cli + - name: Run taplo + run: taplo fmt --check --diff diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5df217a1f..e922bb3de 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ on: jobs: skip_check: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 outputs: should_skip: ${{ steps.skip_check.outputs.should_skip }} steps: @@ -29,7 +29,7 @@ jobs: (github.event.action == 'ready_for_review' || needs.skip_check.outputs.should_skip != 'true') name: Run Tests timeout-minutes: 30 - 
runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: matrix: diff --git a/.gitignore b/.gitignore index 3f7f23880..b432b7fc9 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target log.txt logs/ table_cache_dev_* +.DS_Store diff --git a/Cargo.lock b/Cargo.lock index 8208ccf7c..b3890f899 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -59,9 +59,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" dependencies = [ "anstyle", "anstyle-parse", @@ -74,97 +74,43 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" 
-version = "3.0.4" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" - -[[package]] -name = "ark-ff" -version = "0.4.2" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" -dependencies = [ - "ark-ff-asm", - "ark-ff-macros", - "ark-serialize", - "ark-std", - "derivative", - "digest", - "itertools 0.10.5", - "num-bigint", - "num-traits", - "paste", - "rustc_version", - "zeroize", -] - -[[package]] -name = "ark-ff-asm" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" -dependencies = [ - "quote", - "syn 1.0.109", -] - -[[package]] -name = "ark-ff-macros" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" -dependencies = [ - "num-bigint", - "num-traits", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "ark-serialize" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" -dependencies = [ - "ark-std", - "digest", - "num-bigint", -] +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "ark-std" @@ -189,17 +135,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -251,15 +186,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "blake2" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" -dependencies = [ - "digest", -] - [[package]] name = "blake2b_simd" version = "1.0.2" @@ -271,15 +197,6 @@ dependencies = [ "constant_time_eq", ] -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - [[package]] name = "bumpalo" version = "3.16.0" @@ -288,9 +205,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" [[package]] name = "byteorder" @@ -306,9 +223,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.28" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "shlex", ] @@ -324,8 +241,8 @@ dependencies = [ "anyhow", "ceno-examples", "elf", - "strum 0.25.0", - "strum_macros 
0.25.3", + "strum", + "strum_macros", "tracing", ] @@ -334,7 +251,6 @@ name = "ceno_rt" version = "0.1.0" dependencies = [ "riscv", - "riscv-rt", ] [[package]] @@ -346,14 +262,13 @@ dependencies = [ "ceno_emul", "cfg-if", "clap", - "const_env", "criterion", "ff", "ff_ext", "generic_static", "glob", "goldilocks", - "itertools 0.12.1", + "itertools 0.13.0", "mpcs", "multilinear_extensions", "paste", @@ -363,8 +278,8 @@ dependencies = [ "rayon", "serde", "serde_json", - "strum 0.25.0", - "strum_macros 0.25.3", + "strum", + "strum_macros", "sumcheck", "tempfile", "thread_local", @@ -445,10 +360,10 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -459,9 +374,9 @@ checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "colored" @@ -493,26 +408,6 @@ dependencies = [ "tiny-keccak", ] -[[package]] -name = "const_env" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e9e4f72c6e3398ca6da372abd9affd8f89781fe728869bbf986206e9af9627e" -dependencies = [ - "const_env_impl", -] - -[[package]] -name = "const_env_impl" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a4f51209740b5e1589e702b3044cdd4562cef41b6da404904192ffffb852d62" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "constant_time_eq" version = "0.3.1" @@ -575,9 +470,9 @@ dependencies = [ [[package]] name = "critical-section" 
-version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64009896348fc5af4222e9cf7d7d82a95a256c634ebcf61c53e4ea461422242" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" [[package]] name = "crossbeam-channel" @@ -632,7 +527,8 @@ dependencies = [ [[package]] name = "ctr" version = "0.9.2" -source = "git+https://github.com/RustCrypto/block-modes?rev=78fefba#78fefba5a80f3d788fd8ba7c4483625148f5dc18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" dependencies = [ "cipher", ] @@ -646,28 +542,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "derivative" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "crypto-common", - "subtle", -] - [[package]] name = "either" version = "1.13.0" @@ -686,17 +560,27 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89" +[[package]] +name = "env_filter" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] + [[package]] name = "env_logger" -version = "0.7.1" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36" +checksum = 
"e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" dependencies = [ - "atty", + "anstream", + "anstyle", + "env_filter", "humantime", "log", - "regex", - "termcolor", ] [[package]] @@ -814,13 +698,12 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "crossbeam-channel", "ff", "ff_ext", "goldilocks", - "itertools 0.12.1", + "itertools 0.13.0", "multilinear_extensions", "pprof", "rayon", @@ -844,7 +727,7 @@ dependencies = [ "ff_ext", "gkr", "goldilocks", - "itertools 0.12.1", + "itertools 0.13.0", "multilinear_extensions", "simple-frontend", "sumcheck", @@ -860,10 +743,10 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/hero78119/Goldilocks#7798793666c1205d50fc80b32b23a4430fce31e5" +source = "git+https://github.com/scroll-tech/ceno-Goldilocks#29a15d186ce4375dab346a3cc9eca6e43540cb8d" dependencies = [ "ff", - "halo2curves 0.1.0", + "halo2curves", "itertools 0.12.1", "rand_core", "serde", @@ -911,25 +794,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "halo2curves" -version = "0.3.3" -source = "git+https://github.com/privacy-scaling-explorations/halo2curves?tag=0.3.3#8e4cb9f0c66c864e8ca25da07f50ae95f664a5b7" -dependencies = [ - "ff", - "group", - "lazy_static", - "num-bigint", - "num-traits", - "pasta_curves", - "paste", - "rand", - "rand_core", - "serde", - "static_assertions", - "subtle", -] - [[package]] name = "hashbrown" version = "0.14.5" @@ -947,27 +811,12 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.4.0" @@ -982,12 +831,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "humantime" -version = "1.3.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" -dependencies = [ - "quick-error", -] +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "indexmap" @@ -1032,7 +878,7 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ - "hermit-abi 0.4.0", + "hermit-abi", "libc", "windows-sys 0.52.0", ] @@ -1070,6 +916,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1078,22 +933,13 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cb94a0ffd3f3ee755c20f7d8752f45cac88605a4dcf808abcff72873296ec7b" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] -[[package]] -name = "keccak" -version = "0.1.5" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" -dependencies = [ - "cpufeatures", -] - [[package]] name = "keccak-hash" version = "0.8.0" @@ -1115,9 +961,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "linux-raw-sys" @@ -1181,17 +1027,13 @@ dependencies = [ "aes", "ark-std", "bitvec", - "blake2", - "blake2b_simd", "criterion", "ctr", "ff", "ff_ext", "generic-array", - "gkr", "goldilocks", - "halo2curves 0.3.3", - "itertools 0.12.1", + "itertools 0.13.0", "multilinear_extensions", "num-bigint", "num-integer", @@ -1201,8 +1043,6 @@ dependencies = [ "rand_chacha", "rayon", "serde", - "sha3", - "simple-frontend", "transcript", ] @@ -1211,18 +1051,15 @@ name = "multilinear_extensions" version = "0.1.0" dependencies = [ "ark-std", - "criterion", "env_logger", "ff", "ff_ext", "goldilocks", - "itertools 0.12.1", + "itertools 0.13.0", "log", "rayon", "serde", "tracing", - "tracing-flame", - "tracing-subscriber", ] [[package]] @@ -1404,9 +1241,9 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "plonky2" @@ -1549,19 +1386,13 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" 
+checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] -[[package]] -name = "quick-error" -version = "1.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" - [[package]] name = "quick-xml" version = "0.26.0" @@ -1648,9 +1479,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -1701,34 +1532,33 @@ dependencies = [ [[package]] name = "riscv" -version = "0.11.1" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f5c1b8bf41ea746266cdee443d1d1e9125c86ce1447e1a2615abd34330d33a9" +checksum = "5ea8ff73d3720bdd0a97925f0bf79ad2744b6da8ff36be3840c48ac81191d7a7" dependencies = [ "critical-section", "embedded-hal", + "paste", + "riscv-macros", + "riscv-pac", ] [[package]] -name = "riscv-rt" -version = "0.12.2" +name = "riscv-macros" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0d35e32cf1383183e8885d8a9aa4402a087fd094dc34c2cb6df6687d0229dfe" +checksum = "f265be5d634272320a7de94cea15c22a3bfdd4eb42eb43edc528415f066a1f25" dependencies = [ - "riscv", - "riscv-rt-macros", + "proc-macro2", + "quote", + "syn 2.0.85", ] [[package]] -name = "riscv-rt-macros" -version = "0.2.1" +name = "riscv-pac" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d100d466dbb76681ef6a9386f3da9abc570d57394e86da0ba5af8c4408486d" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] +checksum = "8188909339ccc0c68cfb5a04648313f09621e8b87dc03095454f1a11f6c5d436" [[package]] name = "rustc-demangle" 
@@ -1736,20 +1566,11 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] - [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" dependencies = [ "bitflags 2.6.0", "errno", @@ -1760,9 +1581,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -1785,37 +1606,31 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "semver" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" - [[package]] name = "serde" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.213" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -1823,16 +1638,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha3" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" -dependencies = [ - "digest", - "keccak", -] - [[package]] name = "sharded-slab" version = "0.1.7" @@ -1852,12 +1657,10 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" name = "simple-frontend" version = "0.1.0" dependencies = [ - "ark-std", "ff", "ff_ext", "goldilocks", - "itertools 0.12.1", - "rayon", + "itertools 0.13.0", "serde", ] @@ -1867,15 +1670,13 @@ version = "0.1.0" dependencies = [ "ark-std", "cfg-if", - "const_env", "criterion", "ff", "ff_ext", "gkr", "gkr-graph", "goldilocks", - "itertools 0.12.1", - "mpcs", + "itertools 0.13.0", "multilinear_extensions", "paste", "pprof", @@ -1883,8 +1684,8 @@ dependencies = [ "serde", "simple-frontend", "singer-utils", - "strum 0.25.0", - "strum_macros 0.25.3", + "strum", + "strum_macros", "tracing", "tracing-flame", "tracing-subscriber", @@ -1901,11 +1702,11 @@ dependencies = [ "gkr", "gkr-graph", "goldilocks", - "itertools 0.12.1", + "itertools 0.13.0", "multilinear_extensions", "simple-frontend", - "strum 0.26.3", - "strum_macros 0.26.4", + "strum", + "strum_macros", "sumcheck", "transcript", ] @@ -1946,42 +1747,23 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" - [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.79", -] - [[package]] name = "strum_macros" version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -1994,17 +1776,14 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" name = "sumcheck" version = "0.1.0" dependencies = [ - "ark-ff", "ark-std", - "const_env", "criterion", "crossbeam-channel", "ff", "ff_ext", "goldilocks", - "itertools 0.12.1", + "itertools 0.13.0", "multilinear_extensions", - "pprof", "rayon", "serde", "tracing", @@ -2047,9 +1826,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" dependencies = [ "proc-macro2", "quote", @@ -2075,33 +1854,24 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "thiserror" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" +checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.64" +version = "1.0.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" +checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -2152,7 +1922,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", ] [[package]] @@ -2209,14 +1979,11 @@ dependencies = [ name = "transcript" version = "0.1.0" dependencies = [ - "ark-std", "crossbeam-channel", "ff", "ff_ext", "goldilocks", - "halo2curves 0.1.0", "poseidon", - "rayon", "serde", ] @@ -2262,9 +2029,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" [[package]] name = "valuable" @@ -2296,9 +2063,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ef073ced962d62984fb38a36e5fdc1a2b23c9e0e1fa0689bb97afa4202ef6887" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -2307,24 +2074,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4bfab14ef75323f4eb75fa52ee0a3fb59611977fd3240da19b2cf36ff85030e" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7bec9830f60924d9ceb3ef99d55c155be8afa76954edffbb5936ff4509474e7" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2332,28 +2099,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c74f6e152a76a2ad448e223b0fc0b6b5747649c3d769cc6bf45737bf97d0ed6" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.85", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42f6c679374623f295a8623adfe63d9284091245c3504bde47c17a3ce2777d9" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "web-sys" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"44188d185b5bdcae1052d08bcbcf9091a5524038d4572cc4f4f2bb9d5554ddd9" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -2575,25 +2342,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", -] - -[[package]] -name = "zeroize" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.79", + "syn 2.0.85", ] diff --git a/Cargo.toml b/Cargo.toml index 5aa326e44..3972004e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,46 +1,55 @@ [workspace] +exclude = ["examples"] members = [ - "ceno_emul", - "examples-builder", - "ceno_rt", - "gkr", - "gkr-graph", - "mpcs", - "multilinear_extensions", - "simple-frontend", - "singer", - "singer-utils", - "sumcheck", - "transcript", - "ceno_zkvm", - "poseidon" + "ceno_emul", + "examples-builder", + "ceno_rt", + "gkr", + "gkr-graph", + "mpcs", + "multilinear_extensions", + "simple-frontend", + "singer", + "singer-utils", + "sumcheck", + "transcript", + "ceno_zkvm", + "poseidon", ] -exclude = ["examples"] resolver = "2" [workspace.package] -version = "0.1.0" edition = "2021" license = "MIT OR Apache-2.0" +version = "0.1.0" [workspace.dependencies] -ark-std = { version = "0.4" } -# ark-std = { version = "0.4", features = ["print-trace"] } -ark-poly = "0.4" +ark-std = "0.4" +cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } +crossbeam-channel = "0.5" ff = "0.13" -goldilocks = { git = "https://github.com/zhenfeizhang/Goldilocks" } -halo2curves = "0.1.0" 
+goldilocks = { git = "https://github.com/scroll-tech/ceno-Goldilocks" } +itertools = "0.13" +paste = "1" +plonky2 = "0.2" poseidon = { path = "./poseidon" } -serde = { version = "1.0", features = ["derive"] } -subtle = "2.6.1" -rand_core = "0.6.4" +pprof = { version = "0.13", features = ["flamegraph"] } +rand = "0.8" +rand_chacha = { version = "0.3", features = ["serde1"] } +rand_core = "0.6" rand_xorshift = "0.3" rayon = "1.10" -rand_chacha = { version = "0.3.1", features = ["serde1"] } - -[patch."https://github.com/zhenfeizhang/Goldilocks"] -goldilocks = { git = "https://github.com/hero78119/Goldilocks" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +strum = "0.26" +strum_macros = "0.26" +subtle = "2.6" +tracing = { version = "0.1", features = [ + "attributes", +] } +tracing-flame = "0.2" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } [profile.release] lto = "thin" diff --git a/Makefile.toml b/Makefile.toml index 9845a9537..027f51d2c 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -1,7 +1,6 @@ [env] CARGO_MAKE_EXTEND_WORKSPACE_MAKEFILE = true CORE = { script = ["nproc"] } -RAYON_NUM_THREADS = "${CORE}" CUR_TARGET = { script = [''' if [ -z "${TARGET}" ]; then TARGET=$(rustc -vV | grep "host" | awk '{print $2}') @@ -10,53 +9,58 @@ CUR_TARGET = { script = [''' echo "${TARGET}" fi '''] } +RAYON_NUM_THREADS = "${CORE}" [tasks.tests] -workspace = false -command = "cargo" args = [ - "test", - "--lib", - "--release", - "--target", - "${CUR_TARGET}", - "--workspace", - "--exclude", - "ceno_rt", + "test", + # Run everything but 'benches'. 
+ "--lib", + "--bins", + "--tests", + "--examples", + "--release", + "--target", + "${CUR_TARGET}", + "--workspace", + "--exclude", + "ceno_rt", ] +command = "cargo" +workspace = false [tasks.fmt-all-check] -workspace = false -command = "cargo" args = ["fmt", "--all", "--", "--check"] +command = "cargo" +workspace = false [tasks.fmt-all] -workspace = false -command = "cargo" args = ["fmt", "--all"] +command = "cargo" +workspace = false [tasks.clippy-all] -workspace = false -command = "cargo" args = ["clippy", "--all-features", "--all-targets", "--", "-D", "warnings"] +command = "cargo" +workspace = false [tasks.fmt] -workspace = false -command = "cargo" args = ["fmt", "-p", "ceno_zkvm", "--", "--check"] +command = "cargo" +workspace = false [tasks.clippy] -workspace = false -command = "cargo" args = [ - "clippy", - "--workspace", - "--all-targets", - "--exclude", - "ceno_rt", - "--target", - "${CUR_TARGET}", - "--", - "-D", - "warnings", + "clippy", + "--workspace", + "--all-targets", + "--exclude", + "ceno_rt", + "--target", + "${CUR_TARGET}", + "--", + "-D", + "warnings", ] +command = "cargo" +workspace = false diff --git a/build.rs b/build.rs deleted file mode 100644 index 3e31cb0a9..000000000 --- a/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("cargo:rerun-if-env-changed=RAYON_NUM_THREADS"); -} diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index a615910f9..8f78e5805 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -1,17 +1,19 @@ [package] -name = "ceno_emul" -version.workspace = true edition.workspace = true license.workspace = true +name = "ceno_emul" +version.workspace = true [dependencies] anyhow = { version = "1.0", default-features = false } -strum = "0.25.0" -strum_macros = "0.25.3" -tracing = { version = "0.1", default-features = false, features = [ - "attributes", -] } -elf = { version = "0.7.4" } +elf = "0.7" +strum.workspace = true +strum_macros.workspace = true +tracing.workspace = true [dev-dependencies] 
ceno-examples = { path = "../examples-builder" } + +[features] +default = ["forbid_overflow"] +forbid_overflow = [] diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 132770ddf..0ce39f56b 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -79,6 +79,10 @@ impl ByteAddr { WordAddr(self.0 / WORD_SIZE as u32) } + pub const fn shift(self) -> u32 { + self.0 & 0x03 + } + pub const fn is_aligned(&self) -> bool { self.0 % WORD_SIZE as u32 == 0 } diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 2b1719234..f32e462f1 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -15,3 +15,6 @@ pub use rv32im::{DecodedInstruction, EmuContext, InsnCategory, InsnCodes, InsnKi mod elf; pub use elf::Program; + +mod rv32im_encode; +pub use rv32im_encode::encode_rv32; diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 4192607cf..264d810f5 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -25,11 +25,26 @@ impl Platform { } pub const fn ram_start(&self) -> Addr { - 0x8000_0000 + let ram_start = 0x8000_0000; + if cfg!(feature = "forbid_overflow") { + // -1<<11 == 0x800 is the smallest negative 'immediate' + // offset we can have in memory instructions. + // So if we stay away from it, we are safe. + assert!(ram_start >= 0x800); + } + ram_start } pub const fn ram_end(&self) -> Addr { 0xFFFF_FFFF + - if cfg!(feature = "forbid_overflow") { + // (1<<11) - 1 == 0x7ff is the largest positive 'immediate' + // offset we can have in memory instructions. + // So if we stay away from it, we are safe. 
+ 0x7FF + } else { + 0 + } } pub fn is_ram(&self, addr: Addr) -> bool { diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 17274386b..2abd03cce 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -24,9 +24,6 @@ pub trait EmuContext { // Handle environment call fn ecall(&mut self) -> Result; - // Handle halt - fn halt(&mut self, pc: ByteAddr); - // Handle a machine return fn mret(&self) -> Result; @@ -221,28 +218,6 @@ impl DecodedInstruction { } } - #[allow(dead_code)] - pub fn from_raw(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> Self { - // limit the range of inputs - let rs2 = rs2 & 0x1f; // 5bits mask - let rs1 = rs1 & 0x1f; - let rd = rd & 0x1f; - let func7 = kind.codes().func7; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - let insn = func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode; - Self { - insn, - top_bit: func7 | 0x80, - func7, - rs2, - rs1, - func3, - rd, - opcode, - } - } - pub fn encoded(&self) -> u32 { self.insn } @@ -329,7 +304,7 @@ impl DecodedInstruction { match self.codes() { InsnCodes { format: R | U, .. } => false, InsnCodes { - kind: SLLI | SRLI | SRAI | ADDI | SW | LW, + kind: SLLI | SRLI | SRAI | ADDI, .. 
} => false, _ => self.top_bit != 0, diff --git a/ceno_emul/src/rv32im_encode.rs b/ceno_emul/src/rv32im_encode.rs new file mode 100644 index 000000000..4a7fe3e7f --- /dev/null +++ b/ceno_emul/src/rv32im_encode.rs @@ -0,0 +1,107 @@ +use crate::{InsnKind, rv32im::InsnFormat}; + +const MASK_4_BITS: u32 = 0xF; +const MASK_5_BITS: u32 = 0x1F; +const MASK_6_BITS: u32 = 0x3F; +const MASK_7_BITS: u32 = 0x7F; +const MASK_8_BITS: u32 = 0xFF; +const MASK_10_BITS: u32 = 0x3FF; +const MASK_12_BITS: u32 = 0xFFF; + +pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> u32 { + match kind.codes().format { + InsnFormat::R => encode_r(kind, rs1, rs2, rd), + InsnFormat::I => encode_i(kind, rs1, rd, imm), + InsnFormat::S => encode_s(kind, rs1, rs2, imm), + InsnFormat::B => encode_b(kind, rs1, rs2, imm), + InsnFormat::U => encode_u(kind, rd, imm), + InsnFormat::J => encode_j(kind, rd, imm), + } +} + +// R-Type +// 25 20 15 12 7 0 +// +------+-----+-----+--------+----+-------+ +// funct7 | rs2 | rs1 | funct3 | rd | opcode +const fn encode_r(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> u32 { + let rs2 = rs2 & MASK_5_BITS; // 5-bits mask + let rs1 = rs1 & MASK_5_BITS; + let rd = rd & MASK_5_BITS; + let func7 = kind.codes().func7; + let func3 = kind.codes().func3; + let opcode = kind.codes().opcode; + func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode +} + +// I-Type +// 20 15 12 7 0 +// +---------+-----+--------+----+-------+ +// imm[0:11] | rs1 | funct3 | rd | opcode +const fn encode_i(kind: InsnKind, rs1: u32, rd: u32, imm: u32) -> u32 { + let rs1 = rs1 & MASK_5_BITS; + let rd = rd & MASK_5_BITS; + let func3 = kind.codes().func3; + let opcode = kind.codes().opcode; + let imm = imm & MASK_12_BITS; + imm << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode +} + +// S-Type +// 25 20 15 12 7 0 +// +---------+-----+-----+--------+----------+-------+ +// imm[5:11] | rs2 | rs1 | funct3 | imm[0:4] | opcode +const fn encode_s(kind: InsnKind, rs1: 
u32, rs2: u32, imm: u32) -> u32 { + let rs2 = rs2 & MASK_5_BITS; + let rs1 = rs1 & MASK_5_BITS; + let func3 = kind.codes().func3; + let opcode = kind.codes().opcode; + let imm_lo = imm & MASK_5_BITS; + let imm_hi = (imm >> 5) & MASK_7_BITS; // 7-bits mask + imm_hi << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | imm_lo << 7 | opcode +} + +// B-Type +// 31 25 20 15 12 8 7 0 +// +-------+-----------+-----+-----+--------+----------+---------+-------+ +// imm[12] | imm[5:10] | rs2 | rs1 | funct3 | imm[1:4] | imm[11] | opcode +const fn encode_b(kind: InsnKind, rs1: u32, rs2: u32, imm: u32) -> u32 { + let rs2 = rs2 & MASK_5_BITS; + let rs1 = rs1 & MASK_5_BITS; + let func3 = kind.codes().func3; + let opcode = kind.codes().opcode; + let imm_1_4 = (imm >> 1) & MASK_4_BITS; // skip imm[0] + let imm_5_10 = (imm >> 5) & MASK_6_BITS; + ((imm >> 12) & 1) << 31 + | imm_5_10 << 25 + | rs2 << 20 + | rs1 << 15 + | func3 << 12 + | imm_1_4 << 8 + | ((imm >> 11) & 1) << 7 + | opcode +} + +// J-Type +// 31 21 20 12 7 0 +// +-------+-----------+---------+------------+----+-------+ +// imm[20] | imm[1:10] | imm[11] | imm[12:19] | rd | opcode +const fn encode_j(kind: InsnKind, rd: u32, imm: u32) -> u32 { + let rd = rd & MASK_5_BITS; + let opcode = kind.codes().opcode; + let imm_1_10 = (imm >> 1) & MASK_10_BITS; // skip imm[0] + let imm_12_19 = (imm >> 12) & MASK_8_BITS; + ((imm >> 20) & 1) << 31 + | imm_1_10 << 21 + | ((imm >> 11) & 1) << 20 + | imm_12_19 << 12 + | rd << 7 + | opcode +} + +// U-Type +// 12 7 0 +// +----------+----+--------+ +// imm[12:31] | rd | opcode +const fn encode_u(kind: InsnKind, rd: u32, imm: u32) -> u32 { + (imm >> 12) << 12 | (rd & MASK_5_BITS) << 7 | kind.codes().opcode +} diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index e8bcd5542..b9423f418 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -56,7 +56,7 @@ impl StepRecord { pub fn new_r_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: Word, + insn_code: u32, rs1_read: 
Word, rs2_read: Word, rd: Change, @@ -70,6 +70,7 @@ impl StepRecord { Some(rs1_read), Some(rs2_read), Some(rd), + None, prev_cycle, ) } @@ -77,7 +78,7 @@ impl StepRecord { pub fn new_b_instruction( cycle: Cycle, pc: Change, - insn_code: Word, + insn_code: u32, rs1_read: Word, rs2_read: Word, prev_cycle: Cycle, @@ -89,16 +90,38 @@ impl StepRecord { Some(rs1_read), Some(rs2_read), None, + None, prev_cycle, ) } pub fn new_i_instruction( + cycle: Cycle, + pc: Change, + insn_code: u32, + rs1_read: Word, + rd: Change, + prev_cycle: Cycle, + ) -> StepRecord { + StepRecord::new_insn( + cycle, + pc, + insn_code, + Some(rs1_read), + None, + Some(rd), + None, + prev_cycle, + ) + } + + pub fn new_im_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: Word, + insn_code: u32, rs1_read: Word, rd: Change, + mem_op: ReadOp, prev_cycle: Cycle, ) -> StepRecord { let pc = Change::new(pc, pc + PC_STEP_SIZE); @@ -109,6 +132,14 @@ impl StepRecord { Some(rs1_read), None, Some(rd), + Some(WriteOp { + addr: mem_op.addr, + value: Change { + before: mem_op.value, + after: mem_op.value, + }, + previous_cycle: mem_op.previous_cycle, + }), prev_cycle, ) } @@ -116,31 +147,55 @@ impl StepRecord { pub fn new_u_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: Word, + insn_code: u32, rd: Change, prev_cycle: Cycle, ) -> StepRecord { let pc = Change::new(pc, pc + PC_STEP_SIZE); - StepRecord::new_insn(cycle, pc, insn_code, None, None, Some(rd), prev_cycle) + StepRecord::new_insn(cycle, pc, insn_code, None, None, Some(rd), None, prev_cycle) } pub fn new_j_instruction( cycle: Cycle, pc: Change, - insn_code: Word, + insn_code: u32, rd: Change, prev_cycle: Cycle, ) -> StepRecord { - StepRecord::new_insn(cycle, pc, insn_code, None, None, Some(rd), prev_cycle) + StepRecord::new_insn(cycle, pc, insn_code, None, None, Some(rd), None, prev_cycle) } + pub fn new_s_instruction( + cycle: Cycle, + pc: ByteAddr, + insn_code: u32, + rs1_read: Word, + rs2_read: Word, + memory_op: WriteOp, + prev_cycle: Cycle, + 
) -> StepRecord { + let pc = Change::new(pc, pc + PC_STEP_SIZE); + StepRecord::new_insn( + cycle, + pc, + insn_code, + Some(rs1_read), + Some(rs2_read), + None, + Some(memory_op), + prev_cycle, + ) + } + + #[allow(clippy::too_many_arguments)] fn new_insn( cycle: Cycle, pc: Change, - insn_code: Word, + insn_code: u32, rs1_read: Option, rs2_read: Option, rd: Option>, + memory_op: Option, previous_cycle: Cycle, ) -> StepRecord { let insn = DecodedInstruction::new(insn_code); @@ -163,7 +218,7 @@ impl StepRecord { value: rd, previous_cycle, }), - memory_op: None, + memory_op, } } @@ -249,12 +304,6 @@ impl Tracer { self.record.pc.after = pc; } - pub fn halt(&mut self, pc: ByteAddr) { - let pc_addr = CENO_PLATFORM.pc_vma().into(); - self.record.pc.after = pc; - self.track_access(pc_addr, Self::SUBCYCLES_PER_INSN); - } - pub fn fetch(&mut self, pc: WordAddr, value: Word) { self.record.pc.before = pc.baddr(); self.record.insn_code = value; @@ -325,6 +374,11 @@ impl Tracer { pub fn final_accesses(&self) -> &HashMap { &self.latest_accesses } + + /// Return the cycle of the pending instruction (after the last completed step). 
+ pub fn cycle(&self) -> Cycle { + self.record.cycle + } } #[derive(Copy, Clone, Default, PartialEq, Eq)] diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 3611d3881..5d43a3865 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -86,6 +86,11 @@ impl VMState { pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { self.registers[idx] = value; } + + fn halt(&mut self) { + self.set_pc(0.into()); + self.halted = true; + } } impl EmuContext for VMState { @@ -96,7 +101,7 @@ impl EmuContext for VMState { let exit_code = self.load_register(self.platform.reg_arg0())?; tracing::debug!("halt with exit_code={}", exit_code); - self.halt(ByteAddr(0)); + self.halt(); Ok(true) } else { self.trap(TrapCause::EcallError) @@ -126,12 +131,6 @@ impl EmuContext for VMState { self.pc = after.0; } - fn halt(&mut self, pc: ByteAddr) { - self.pc = pc.0; - self.halted = true; - self.tracer.halt(ByteAddr(pc.0)); - } - /// Load a register and record this operation. fn load_register(&mut self, idx: RegIdx) -> Result { self.tracer.load_register(idx, self.peek_register(idx)); diff --git a/ceno_rt/.cargo/config.toml b/ceno_rt/.cargo/config.toml index c08ff50ab..9cea5ae7c 100644 --- a/ceno_rt/.cargo/config.toml +++ b/ceno_rt/.cargo/config.toml @@ -1,13 +1,15 @@ [target.riscv32im-unknown-none-elf] rustflags = [ - "-C", "link-arg=-Tmemory.x", + "-C", + "link-arg=-Tmemory.x", #"-C", "link-arg=-Tlink.x", // Script from riscv_rt. 
- "-C", "link-arg=-Tceno_link.x", + "-C", + "link-arg=-Tceno_link.x", ] [build] target = "riscv32im-unknown-none-elf" [profile.release] -panic = "abort" lto = true +panic = "abort" diff --git a/ceno_rt/Cargo.toml b/ceno_rt/Cargo.toml index c295821f2..505f67dd8 100644 --- a/ceno_rt/Cargo.toml +++ b/ceno_rt/Cargo.toml @@ -1,9 +1,8 @@ [package] -name = "ceno_rt" -version.workspace = true edition.workspace = true license.workspace = true +name = "ceno_rt" +version.workspace = true [dependencies] -riscv = "0.11.1" -riscv-rt = "0.12.2" +riscv = "0.12" diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index 47253856f..07c53a2bc 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -60,18 +60,34 @@ _start: ", ); +#[macro_export] +macro_rules! entry { + ($path:path) => { + // Type check the given path + const CENO_ENTRY: fn() = $path; + + mod ceno_generated_main { + #[no_mangle] + extern "C" fn bespoke_entrypoint() { + super::CENO_ENTRY(); + } + } + }; +} + /// _start_rust is called by the assembly entry point and it calls the Rust main(). #[no_mangle] unsafe extern "C" fn _start_rust() -> ! { allocator::init_heap(); - main(); + { + extern "C" { + fn bespoke_entrypoint(); + } + bespoke_entrypoint(); + } halt(0) } -extern "C" { - fn main(); -} - extern "C" { // The address of this variable is the start of the stack (growing downwards). 
static _stack_start: u8; diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 7a2743126..ac16a6de9 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -1,58 +1,56 @@ [package] -name = "ceno_zkvm" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "ceno_zkvm" +version.workspace = true [dependencies] ark-std.workspace = true ff.workspace = true goldilocks.workspace = true +rand_chacha.workspace = true rayon.workspace = true serde.workspace = true -rand_chacha.workspace = true -transcript = { path = "../transcript" } -sumcheck = { version = "0.1.0", path = "../sumcheck" } -multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } -ff_ext = { path = "../ff_ext" } ceno_emul = { path = "../ceno_emul" } +ff_ext = { path = "../ff_ext" } mpcs = { path = "../mpcs" } +multilinear_extensions = { version = "0", path = "../multilinear_extensions" } +sumcheck = { version = "0", path = "../sumcheck" } +transcript = { path = "../transcript" } -itertools = "0.12.1" -strum = "0.25.0" -strum_macros = "0.25.3" -paste = "1.0.15" -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -tracing-flame = "0.2.0" -tracing = "0.1.40" +itertools.workspace = true +paste.workspace = true +strum.workspace = true +strum_macros.workspace = true +tracing.workspace = true +tracing-flame.workspace = true +tracing-subscriber.workspace = true -rand = "0.8" -thread_local = "1.1.8" -generic_static = "0.2.0" -clap = { version = "4.5.20", features = ["derive"] } -tempfile = "3.13.0" +clap = { version = "4.5", features = ["derive"] } +generic_static = "0.2" +rand.workspace = true +tempfile = "3.13" +thread_local = "1.1" [dev-dependencies] -pprof = { version = "0.13", features = ["flamegraph"]} -criterion = { version = "0.5", features = ["html_reports"] } -cfg-if = "1.0.0" -const_env = "0.1.2" -serde_json = "1.0.128" 
-base64 = "0.22.1" +base64 = "0.22" +cfg-if.workspace = true +criterion.workspace = true +pprof.workspace = true +serde_json.workspace = true [build-dependencies] glob = "0.3" [features] -default = ["riv32"] +default = ["riv32", "forbid_overflow"] +flamegraph = ["pprof/flamegraph", "pprof/criterion"] +forbid_overflow = [] +non_pow2_rayon_thread = [] riv32 = [] riv64 = [] -non_pow2_rayon_thread = [] -flamegraph = ["pprof/flamegraph", "pprof/criterion"] [[bench]] -name = "riscv_add" harness = false +name = "riscv_add" diff --git a/ceno_zkvm/Makefile.toml b/ceno_zkvm/Makefile.toml index 77f4845a7..a4965ba50 100644 --- a/ceno_zkvm/Makefile.toml +++ b/ceno_zkvm/Makefile.toml @@ -4,6 +4,6 @@ CORE = { script = ["nproc"] } RAYON_NUM_THREADS = "${CORE}" [tasks.riscv_opcodes_flamegraph] -env = { "RUST_LOG" = "debug", "RAYON_NUM_THREADS" = "8"} -command = "cargo" args = ["run", "--package", "ceno_zkvm", "--release", "--example", "riscv_opcodes"] +command = "cargo" +env = { "RUST_LOG" = "debug", "RAYON_NUM_THREADS" = "8" } diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9d69e120f..16d5cfe67 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -7,7 +7,6 @@ use ceno_zkvm::{ scheme::prover::ZKVMProver, structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; -use const_env::from_env; use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; @@ -37,31 +36,9 @@ cfg_if::cfg_if! 
{ criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_add(c: &mut Criterion) { type Pcs = BasefoldDefault; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; let mut zkvm_cs = ZKVMConstraintSystem::default(); let _ = zkvm_cs.register_opcode_circuit::>(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -128,7 +105,6 @@ fn bench_add(c: &mut Criterion) { commit, &[], num_instances, - max_threads, &mut transcript, &challenges, ) diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index c15cbfa38..b979975d2 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,35 +1,33 @@ -use std::{iter, panic, time::Instant}; +use std::{panic, time::Instant}; use ceno_zkvm::{ - instructions::riscv::{arith::AddInstruction, branch::BltuInstruction, jump::JalInstruction}, + declare_program, + instructions::riscv::{Rv32imConfig, constants::EXIT_PC}, scheme::prover::ZKVMProver, - tables::ProgramTableCircuit, + state::GlobalState, + tables::{MemFinalRecord, ProgramTableCircuit, initial_memory, initial_registers}, }; use clap::Parser; -use const_env::from_env; use ceno_emul::{ - ByteAddr, CENO_PLATFORM, - InsnKind::{ADD, BLTU, EANY, JAL}, - StepRecord, VMState, + ByteAddr, CENO_PLATFORM, EmuContext, + InsnKind::{ADD, BLTU, EANY, JAL, LUI, LW}, + StepRecord, Tracer, VMState, WordAddr, encode_rv32, }; use ceno_zkvm::{ - 
instructions::riscv::ecall::HaltInstruction, scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, - tables::{AndTableCircuit, LtuTableCircuit, U16TableCircuit}, }; use ff_ext::ff::Field; use goldilocks::GoldilocksExt2; +use itertools::Itertools; use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use rand_chacha::ChaCha8Rng; use tracing_flame::FlameLayer; use tracing_subscriber::{EnvFilter, Registry, fmt, layer::SubscriberExt}; use transcript::Transcript; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - +const PROGRAM_SIZE: usize = 512; // For now, we assume registers // - x0 is not touched, // - x1 is initialized to 1, @@ -39,14 +37,26 @@ const RAYON_NUM_THREADS: usize = 8; #[allow(clippy::unusual_byte_groupings)] const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011; #[allow(clippy::unusual_byte_groupings)] -const PROGRAM_CODE: [u32; 5] = [ - // func7 rs2 rs1 f3 rd opcode - 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 - 0b_0000000_00011_00010_000_00011_0110011, // add x3, x3, x2 <=> addi x3, x3, -1 - 0b_1_111111_00011_00000_110_1100_1_1100011, // bltu x0, x3, -8 - 0b_0_0000000010_0_00000000_00001_1101111, // jal x1, 4 - ECALL_HALT, // ecall halt -]; +const PROGRAM_CODE: [u32; PROGRAM_SIZE] = { + let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE]; + declare_program!( + program, + // Load parameters from initial RAM. + encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.ram_start()), // lui x10, program_data + encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) + encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) + encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) + // Main loop. + encode_rv32(ADD, 1, 4, 4, 0), // add x4, x1, x4 + encode_rv32(ADD, 2, 3, 3, 0), // add x3, x2, x3 + encode_rv32(BLTU, 0, 3, 0, -8_i32 as u32), // bltu x0, x3, -8 + // End. 
+ encode_rv32(JAL, 0, 0, 1, 4), // jal x1, 4 + ECALL_HALT, // ecall halt + ); + program +}; +type ExampleProgramTableCircuit = ProgramTableCircuit; /// Simple program to greet a person #[derive(Parser, Debug)] @@ -66,27 +76,6 @@ fn main() { type E = GoldilocksExt2; type Pcs = Basefold; - let max_threads = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; - let (flame_layer, _guard) = FlameLayer::with_file("./tracing.folded").unwrap(); let subscriber = Registry::default() .with( @@ -103,141 +92,119 @@ fn main() { let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(&pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); - // opcode circuits - let add_config = zkvm_cs.register_opcode_circuit::>(); - let bltu_config = zkvm_cs.register_opcode_circuit::(); - let jal_config = zkvm_cs.register_opcode_circuit::>(); - let halt_config = zkvm_cs.register_opcode_circuit::>(); - // tables - let u16_range_config = zkvm_cs.register_table_circuit::>(); - let and_config = zkvm_cs.register_table_circuit::>(); - let ltu_config = zkvm_cs.register_table_circuit::>(); - let prog_config = zkvm_cs.register_table_circuit::>(); - - let program_code: Vec = PROGRAM_CODE - .iter() - .cloned() - .chain(iter::repeat(ECALL_HALT)) - .take(512) - .collect(); - let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); - zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); - 
zkvm_fixed_traces.register_opcode_circuit::(&zkvm_cs); - zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); - zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); - - zkvm_fixed_traces.register_table_circuit::>( - &zkvm_cs, - u16_range_config.clone(), - &(), - ); - zkvm_fixed_traces.register_table_circuit::>( - &zkvm_cs, - and_config.clone(), - &(), - ); - zkvm_fixed_traces.register_table_circuit::>( - &zkvm_cs, - ltu_config.clone(), - &(), - ); - zkvm_fixed_traces.register_table_circuit::>( - &zkvm_cs, - prog_config.clone(), - &program_code, - ); - let pk = zkvm_cs - .clone() - .key_gen::(pp, vp, zkvm_fixed_traces) - .expect("keygen failed"); - let vk = pk.get_vk(); - - // proving - let prover = ZKVMProver::new(pk); - let verifier = ZKVMVerifier::new(vk); + let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let prog_config = zkvm_cs.register_table_circuit::>(); + zkvm_cs.register_global_state::(); for instance_num_vars in args.start..args.end { let step_loop = 1 << (instance_num_vars - 1); // 1 step in loop contribute to 2 add instance + + // init vm.x1 = 1, vm.x2 = -1, vm.x3 = step_loop + let program_data: &[u32] = &[1, u32::MAX, step_loop]; + + let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); + + zkvm_fixed_traces.register_table_circuit::>( + &zkvm_cs, + prog_config.clone(), + &PROGRAM_CODE, + ); + + let reg_init = initial_registers(); + let mem_init = initial_memory(program_data); + + config.generate_fixed_traces(&zkvm_cs, &mut zkvm_fixed_traces, ®_init, &mem_init); + + let pk = zkvm_cs + .clone() + .key_gen::(pp.clone(), vp.clone(), zkvm_fixed_traces) + .expect("keygen failed"); + let vk = pk.get_vk(); + + // proving + let prover = ZKVMProver::new(pk); + let verifier = ZKVMVerifier::new(vk); + let mut vm = VMState::new(CENO_PLATFORM); let pc_start = ByteAddr(CENO_PLATFORM.pc_start()).waddr(); - // init vm.x1 = 1, vm.x2 = -1, vm.x3 = num_instances - // vm.x4 += vm.x1 - vm.init_register_unsafe(1usize, 1); - 
vm.init_register_unsafe(2usize, u32::MAX); // -1 in two's complement - vm.init_register_unsafe(3usize, step_loop as u32); - for (i, inst) in program_code.iter().enumerate() { + for (i, inst) in PROGRAM_CODE.iter().enumerate() { vm.init_memory(pc_start + i, *inst); } + for record in &mem_init { + vm.init_memory(record.addr.into(), record.value); + } let all_records = vm .iter_until_halt() .collect::, _>>() - .expect("vm exec failed") - .into_iter() - .collect::>(); - let mut add_records = Vec::new(); - let mut bltu_records = Vec::new(); - let mut jal_records = Vec::new(); - let mut halt_records = Vec::new(); - all_records.into_iter().for_each(|record| { - let kind = record.insn().kind().1; - match kind { - ADD => add_records.push(record), - BLTU => bltu_records.push(record), - JAL => jal_records.push(record), - EANY => { - if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() { - halt_records.push(record); - } - } - _ => {} - } - }); + .expect("vm exec failed"); + + let halt_record = all_records + .iter() + .rev() + .find(|record| { + record.insn().codes().kind == EANY + && record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() + }) + .expect("halt record not found"); - assert_eq!(halt_records.len(), 1); - let exit_code = halt_records[0].rs2().unwrap().value; - let pi = PublicValues::new(exit_code, 0); + let final_access = vm.tracer().final_accesses(); - tracing::info!( - "tracer generated {} ADD records, {} BLTU records, {} JAL records", - add_records.len(), - bltu_records.len(), - jal_records.len(), + let end_cycle: u32 = vm.tracer().cycle().try_into().unwrap(); + let exit_code = halt_record.rs2().unwrap().value; + let pi = PublicValues::new( + exit_code, + CENO_PLATFORM.rom_start(), + Tracer::SUBCYCLES_PER_INSN as u32, + EXIT_PC as u32, + end_cycle, ); let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits - zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &add_config, add_records) - .unwrap(); - zkvm_witness - 
.assign_opcode_circuit::(&zkvm_cs, &bltu_config, bltu_records) + config + .assign_opcode_circuit(&zkvm_cs, &mut zkvm_witness, all_records) .unwrap(); - zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &jal_config, jal_records) - .unwrap(); - zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) - .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); + + // Find the final register values and cycles. + let reg_final = reg_init + .iter() + .map(|rec| { + let index = rec.addr as usize; + let vma: WordAddr = CENO_PLATFORM.register_vma(index).into(); + MemFinalRecord { + value: vm.peek_register(index), + cycle: *final_access.get(&vma).unwrap_or(&0), + } + }) + .collect_vec(); + + // Find the final memory values and cycles. + let mem_final = mem_init + .iter() + .map(|rec| { + let vma: WordAddr = rec.addr.into(); + MemFinalRecord { + value: vm.peek_memory(vma), + cycle: *final_access.get(&vma).unwrap_or(&0), + } + }) + .collect_vec(); + // assign table circuits - zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &u16_range_config, &()) - .unwrap(); - zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &and_config, &()) - .unwrap(); - zkvm_witness - .assign_table_circuit::>(&zkvm_cs, <u_config, &()) + config + .assign_table_circuit(&zkvm_cs, &mut zkvm_witness, ®_final, &mem_final) .unwrap(); + + // assign program circuit zkvm_witness - .assign_table_circuit::>( + .assign_table_circuit::>( &zkvm_cs, &prog_config, - &program_code.len(), + &PROGRAM_CODE.len(), ) .unwrap(); @@ -245,7 +212,7 @@ fn main() { let transcript = Transcript::new(b"riscv"); let mut zkvm_proof = prover - .create_proof(zkvm_witness, pi, max_threads, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); println!( diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index b359588e8..8d16f342d 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -48,27 +48,27 @@ pub trait RegisterChipOperations, 
N: FnOnce( ) -> Result<(Expression, AssertLTConfig), ZKVMError>; } +/// The common representation of a memory address. +pub type AddressExpr = Expression; + /// The common representation of a memory value. -/// Format: `[u16; UINT_LIMBS]`, least-significant-first. -pub type MemoryExpr = [Expression; UINT_LIMBS]; +pub type MemoryExpr = Expression; pub trait MemoryChipOperations, N: FnOnce() -> NR> { - #[allow(dead_code)] fn memory_read( &mut self, name_fn: N, - memory_addr: &MemoryExpr, + memory_addr: &AddressExpr, prev_ts: Expression, ts: Expression, value: MemoryExpr, ) -> Result<(Expression, AssertLTConfig), ZKVMError>; #[allow(clippy::too_many_arguments)] - #[allow(dead_code)] fn memory_write( &mut self, name_fn: N, - memory_addr: &MemoryExpr, + memory_addr: &AddressExpr, prev_ts: Expression, ts: Expression, prev_values: MemoryExpr, diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index e3ae37145..429bba66d 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,13 +4,13 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, expression::{Expression, Fixed, Instance, ToExpr, WitIn}, - instructions::riscv::constants::EXIT_CODE_IDX, + instructions::riscv::constants::{ + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + }, structs::ROMType, tables::InsnRecord, }; -use super::utils::rlc_chip_record; - impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn new(cs: &'a mut ConstraintSystem) -> Self { Self { cs } @@ -40,21 +40,39 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ]) } + pub fn query_init_pc(&mut self) -> Result { + self.cs.query_instance(|| "init_pc", INIT_PC_IDX) + } + + pub fn query_init_cycle(&mut self) -> Result { + self.cs.query_instance(|| "init_cycle", INIT_CYCLE_IDX) + } + + pub fn query_end_pc(&mut self) -> Result { + self.cs.query_instance(|| "end_pc", END_PC_IDX) + } + + pub fn 
query_end_cycle(&mut self) -> Result { + self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX) + } + pub fn lk_record( &mut self, name_fn: N, - rlc_record: Expression, + rom_type: ROMType, + items: Vec>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { - self.cs.lk_record(name_fn, rlc_record) + self.cs.lk_record(name_fn, rom_type, items) } pub fn lk_table_record( &mut self, name_fn: N, + table_len: usize, rlc_record: Expression, multiplicity: Expression, ) -> Result<(), ZKVMError> @@ -62,18 +80,39 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - self.cs.lk_table_record(name_fn, rlc_record, multiplicity) + self.cs + .lk_table_record(name_fn, table_len, rlc_record, multiplicity) + } + + pub fn r_table_record( + &mut self, + name_fn: N, + table_len: usize, + rlc_record: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.r_table_record(name_fn, table_len, rlc_record) + } + + pub fn w_table_record( + &mut self, + name_fn: N, + table_len: usize, + rlc_record: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.w_table_record(name_fn, table_len, rlc_record) } /// Fetch an instruction at a given PC from the Program table. 
pub fn lk_fetch(&mut self, record: &InsnRecord>) -> Result<(), ZKVMError> { - let rlc_record = { - let mut fields = vec![E::BaseField::from(ROMType::Instruction as u64).expr()]; - fields.extend_from_slice(record.as_slice()); - self.rlc_chip_record(fields) - }; - - self.cs.lk_record(|| "fetch", rlc_record) + self.lk_record(|| "fetch", ROMType::Instruction, record.as_slice().to_vec()) } pub fn read_record( @@ -101,11 +140,29 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { } pub fn rlc_chip_record(&self, records: Vec>) -> Expression { - rlc_chip_record( - records, - self.cs.chip_record_alpha.clone(), - self.cs.chip_record_beta.clone(), - ) + self.cs.rlc_chip_record(records) + } + + pub fn create_u8(&mut self, name_fn: N) -> Result + where + NR: Into, + N: FnOnce() -> NR + Clone, + { + let byte = self.cs.create_witin(name_fn.clone())?; + self.assert_ux::<_, _, 8>(name_fn, byte.expr())?; + + Ok(byte) + } + + pub fn create_u16(&mut self, name_fn: N) -> Result + where + NR: Into, + N: FnOnce() -> NR + Clone, + { + let limb = self.cs.create_witin(name_fn.clone())?; + self.assert_ux::<_, _, 16>(name_fn, limb.expr())?; + + Ok(limb) } pub fn require_zero( @@ -126,8 +183,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn require_equal( &mut self, name_fn: N, - target: Expression, - rlc_record: Expression, + a: Expression, + b: Expression, ) -> Result<(), ZKVMError> where NR: Into, @@ -135,7 +192,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { { self.namespace( || "require_equal", - |cb| cb.cs.require_zero(name_fn, target - rlc_record), + |cb| { + cb.cs + .require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form()) + }, ) } @@ -173,6 +233,15 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ) } + pub fn select( + &mut self, + cond: &Expression, + when_true: &Expression, + when_false: &Expression, + ) -> Expression { + cond.clone() * when_true.clone() + (1 - cond.clone()) * when_false.clone() + } + pub(crate) fn assert_ux( &mut self, 
name_fn: N, @@ -184,6 +253,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { { match C { 16 => self.assert_u16(name_fn, expr), + 14 => self.assert_u14(name_fn, expr), 8 => self.assert_byte(name_fn, expr), 5 => self.assert_u5(name_fn, expr), c => panic!("Unsupported bit range {c}"), @@ -197,28 +267,25 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { { self.namespace( || "assert_u5", - |cb| { - let items: Vec> = vec![ - Expression::Constant(E::BaseField::from(ROMType::U5 as u64)), - expr, - ]; - let rlc_record = cb.rlc_chip_record(items); - cb.cs.lk_record(name_fn, rlc_record) - }, + |cb| cb.lk_record(name_fn, ROMType::U5, vec![expr]), ) } + fn assert_u14(&mut self, name_fn: N, expr: Expression) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.lk_record(name_fn, ROMType::U14, vec![expr])?; + Ok(()) + } + fn assert_u16(&mut self, name_fn: N, expr: Expression) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { - let items: Vec> = vec![ - Expression::Constant(E::BaseField::from(ROMType::U16 as u64)), - expr, - ]; - let rlc_record = self.rlc_chip_record(items); - self.lk_record(name_fn, rlc_record)?; + self.lk_record(name_fn, ROMType::U16, vec![expr])?; Ok(()) } @@ -243,9 +310,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { NR: Into, N: FnOnce() -> NR, { - let items: Vec> = vec![(ROMType::U8 as usize).into(), expr]; - let rlc_record = self.rlc_chip_record(items); - self.lk_record(name_fn, rlc_record)?; + self.lk_record(name_fn, ROMType::U8, vec![expr])?; Ok(()) } @@ -275,9 +340,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { b: Expression, c: Expression, ) -> Result<(), ZKVMError> { - let items: Vec> = vec![(rom_type as usize).into(), a, b, c]; - let rlc_record = self.rlc_chip_record(items); - self.lk_record(|| format!("lookup_{:?}", rom_type), rlc_record) + self.lk_record(|| format!("lookup_{:?}", rom_type), rom_type, vec![a, b, c]) } /// Assert `a & b = c` and that `a, b, c` are all bytes. 
diff --git a/ceno_zkvm/src/chip_handler/memory.rs b/ceno_zkvm/src/chip_handler/memory.rs index a82ce4883..9a58c8a04 100644 --- a/ceno_zkvm/src/chip_handler/memory.rs +++ b/ceno_zkvm/src/chip_handler/memory.rs @@ -1,5 +1,5 @@ use crate::{ - chip_handler::{MemoryChipOperations, MemoryExpr}, + chip_handler::{AddressExpr, MemoryChipOperations, MemoryExpr}, circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, @@ -12,11 +12,10 @@ use ff_ext::ExtensionField; impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOperations for CircuitBuilder<'a, E> { - #[allow(dead_code)] fn memory_read( &mut self, name_fn: N, - memory_addr: &MemoryExpr, + memory_addr: &AddressExpr, prev_ts: Expression, ts: Expression, value: MemoryExpr, @@ -25,11 +24,11 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera // READ (a, v, t) let read_record = cb.rlc_chip_record( [ - vec![Expression::::Constant(E::BaseField::from( - RAMType::Memory as u64, - ))], - memory_addr.to_vec(), - value.to_vec(), + vec![ + Expression::::Constant(E::BaseField::from(RAMType::Memory as u64)), + memory_addr.clone(), + ], + vec![value.clone()], vec![prev_ts.clone()], ] .concat(), @@ -37,11 +36,11 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera // Write (a, v, t) let write_record = cb.rlc_chip_record( [ - vec![Expression::::Constant(E::BaseField::from( - RAMType::Memory as u64, - ))], - memory_addr.to_vec(), - value.to_vec(), + vec![ + Expression::::Constant(E::BaseField::from(RAMType::Memory as u64)), + memory_addr.clone(), + ], + vec![value], vec![ts.clone()], ] .concat(), @@ -58,7 +57,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera UINT_LIMBS, )?; - let next_ts = ts + 1.into(); + let next_ts = ts + 1; Ok((next_ts, lt_cfg)) }) @@ -67,7 +66,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera fn memory_write( &mut self, name_fn: N, - memory_addr: &MemoryExpr, + memory_addr: 
&AddressExpr, prev_ts: Expression, ts: Expression, prev_values: MemoryExpr, @@ -77,11 +76,11 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera // READ (a, v, t) let read_record = cb.rlc_chip_record( [ - vec![Expression::::Constant(E::BaseField::from( - RAMType::Memory as u64, - ))], - memory_addr.to_vec(), - prev_values.to_vec(), + vec![ + Expression::::Constant(E::BaseField::from(RAMType::Memory as u64)), + memory_addr.clone(), + ], + vec![prev_values], vec![prev_ts.clone()], ] .concat(), @@ -89,11 +88,11 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera // Write (a, v, t) let write_record = cb.rlc_chip_record( [ - vec![Expression::::Constant(E::BaseField::from( - RAMType::Memory as u64, - ))], - memory_addr.to_vec(), - value.to_vec(), + vec![ + Expression::::Constant(E::BaseField::from(RAMType::Memory as u64)), + memory_addr.clone(), + ], + vec![value], vec![ts.clone()], ] .concat(), @@ -109,7 +108,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> MemoryChipOpera UINT_LIMBS, )?; - let next_ts = ts + 1.into(); + let next_ts = ts + 1; Ok((next_ts, lt_cfg)) }) diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index e96d548d5..254a8b6bc 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -59,7 +59,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe UINT_LIMBS, )?; - let next_ts = ts + 1.into(); + let next_ts = ts + 1; Ok((next_ts, lt_cfg)) }) @@ -111,7 +111,7 @@ impl<'a, E: ExtensionField, NR: Into, N: FnOnce() -> NR> RegisterChipOpe UINT_LIMBS, )?; - let next_ts = ts + 1.into(); + let next_ts = ts + 1; #[cfg(test)] { diff --git a/ceno_zkvm/src/chip_handler/utils.rs b/ceno_zkvm/src/chip_handler/utils.rs index 96db0057d..06f9bb192 100644 --- a/ceno_zkvm/src/chip_handler/utils.rs +++ b/ceno_zkvm/src/chip_handler/utils.rs @@ -3,6 +3,7 @@ use std::iter::successors; use 
crate::expression::Expression; use ff::Field; use ff_ext::ExtensionField; +use itertools::izip; pub fn rlc_chip_record( records: Vec>, @@ -12,12 +13,9 @@ pub fn rlc_chip_record( assert!(!records.is_empty()); let beta_pows = power_sequence(chip_record_beta); - let item_rlc = beta_pows - .into_iter() - .zip(records.iter()) - .map(|(beta, record)| beta * record.clone()) - .reduce(|a, b| a + b) - .expect("reduce error"); + let item_rlc = izip!(records, beta_pows) + .map(|(record, beta)| record * beta) + .sum::>(); item_rlc + chip_record_alpha.clone() } diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 90807b95d..f5ef37825 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -5,6 +5,8 @@ use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; use crate::{ + ROMType, + chip_handler::utils::rlc_chip_record, error::ZKVMError, expression::{Expression, Fixed, Instance, WitIn}, structs::{ProvingKey, VerifyingKey, WitnessId}, @@ -67,6 +69,13 @@ impl NameSpace { pub struct LogupTableExpression { pub multiplicity: Expression, pub values: Expression, + pub table_len: usize, +} + +#[derive(Clone, Debug)] +pub struct SetTableExpression { + pub values: Expression, + pub table_len: usize, } #[derive(Clone, Debug)] @@ -87,6 +96,12 @@ pub struct ConstraintSystem { pub w_expressions: Vec>, pub w_expressions_namespace_map: Vec, + /// init/final ram expression + pub r_table_expressions: Vec>, + pub r_table_expressions_namespace_map: Vec, + pub w_table_expressions: Vec>, + pub w_table_expressions_namespace_map: Vec, + /// lookup expression pub lk_expressions: Vec>, pub lk_expressions_namespace_map: Vec, @@ -110,6 +125,8 @@ pub struct ConstraintSystem { #[cfg(test)] pub debug_map: HashMap>>, + #[cfg(test)] + pub lk_expressions_items_map: Vec<(ROMType, Vec>)>, pub(crate) phantom: PhantomData, } @@ -127,6 +144,10 @@ impl ConstraintSystem { r_expressions_namespace_map: vec![], w_expressions: vec![], 
w_expressions_namespace_map: vec![], + r_table_expressions: vec![], + r_table_expressions_namespace_map: vec![], + w_table_expressions: vec![], + w_table_expressions_namespace_map: vec![], lk_expressions: vec![], lk_expressions_namespace_map: vec![], lk_table_expressions: vec![], @@ -141,6 +162,8 @@ impl ConstraintSystem { #[cfg(test)] debug_map: HashMap::new(), + #[cfg(test)] + lk_expressions_items_map: vec![], phantom: std::marker::PhantomData, } @@ -215,11 +238,30 @@ impl ConstraintSystem { Ok(i) } + pub fn rlc_chip_record(&self, items: Vec>) -> Expression { + rlc_chip_record( + items, + self.chip_record_alpha.clone(), + self.chip_record_beta.clone(), + ) + } + pub fn lk_record, N: FnOnce() -> NR>( &mut self, name_fn: N, - rlc_record: Expression, + rom_type: ROMType, + items: Vec>, ) -> Result<(), ZKVMError> { + let rlc_record = self.rlc_chip_record( + std::iter::once(Expression::Constant(E::BaseField::from(rom_type as u64))) + .chain( + #[cfg(test)] + items.clone(), + #[cfg(not(test))] + items, + ) + .collect(), + ); assert_eq!( rlc_record.degree(), 1, @@ -229,12 +271,15 @@ impl ConstraintSystem { self.lk_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.lk_expressions_namespace_map.push(path); + #[cfg(test)] + self.lk_expressions_items_map.push((rom_type, items)); Ok(()) } pub fn lk_table_record( &mut self, name_fn: N, + table_len: usize, rlc_record: Expression, multiplicity: Expression, ) -> Result<(), ZKVMError> @@ -251,6 +296,7 @@ impl ConstraintSystem { self.lk_table_expressions.push(LogupTableExpression { values: rlc_record, multiplicity, + table_len, }); let path = self.ns.compute_path(name_fn().into()); self.lk_table_expressions_namespace_map.push(path); @@ -258,6 +304,58 @@ impl ConstraintSystem { Ok(()) } + pub fn r_table_record( + &mut self, + name_fn: N, + table_len: usize, + rlc_record: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + assert_eq!( + rlc_record.degree(), + 1, 
+ "rlc record degree {} != 1", + rlc_record.degree() + ); + self.r_table_expressions.push(SetTableExpression { + values: rlc_record, + table_len, + }); + let path = self.ns.compute_path(name_fn().into()); + self.r_table_expressions_namespace_map.push(path); + + Ok(()) + } + + pub fn w_table_record( + &mut self, + name_fn: N, + table_len: usize, + rlc_record: Expression, + ) -> Result<(), ZKVMError> + where + NR: Into, + N: FnOnce() -> NR, + { + assert_eq!( + rlc_record.degree(), + 1, + "rlc record degree {} != 1", + rlc_record.degree() + ); + self.w_table_expressions.push(SetTableExpression { + values: rlc_record, + table_len, + }); + let path = self.ns.compute_path(name_fn().into()); + self.w_table_expressions_namespace_map.push(path); + + Ok(()) + } + pub fn read_record, N: FnOnce() -> NR>( &mut self, name_fn: N, diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index e119198ba..e59a3985d 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -21,7 +21,7 @@ use crate::{ structs::{ChallengeId, WitnessId}, }; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum Expression { /// WitIn(Id) WitIn(WitnessId), @@ -317,7 +317,7 @@ impl Add for Expression { impl Sum for Expression { fn sum>>(iter: I) -> Expression { - iter.fold(Expression::Constant(E::BaseField::ZERO), |acc, x| acc + x) + iter.fold(Expression::ZERO, |acc, x| acc + x) } } @@ -442,6 +442,42 @@ impl Sub for Expression { } } +macro_rules! 
binop_instances { + ($op: ident, $fun: ident, ($($t:ty),*)) => { + $(impl $op> for $t { + type Output = Expression; + + fn $fun(self, rhs: Expression) -> Expression { + Expression::::from(self).$fun(rhs) + } + } + + impl $op<$t> for Expression { + type Output = Expression; + + fn $fun(self, rhs: $t) -> Expression { + self.$fun(Expression::::from(rhs)) + } + })* + }; +} + +binop_instances!( + Add, + add, + (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) +); +binop_instances!( + Sub, + sub, + (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) +); +binop_instances!( + Mul, + mul, + (u8, u16, u32, u64, usize, i8, i16, i32, i64, i128, isize) +); + impl Mul for Expression { type Output = Expression; fn mul(self, rhs: Expression) -> Expression { @@ -603,22 +639,6 @@ impl WitIn { } } -#[macro_export] -/// this is to avoid non-monomial expression -macro_rules! create_witin_from_expr { - // Handle the case for a single expression - ($name:expr, $builder:expr, $debug:expr, $e:expr) => { - WitIn::from_expr($name, $builder, $e, $debug) - }; - // Recursively handle multiple expressions and create a flat tuple with error handling - ($name:expr, $builder:expr, $debug:expr, $e:expr, $($rest:expr),+) => { - { - // Return a Result tuple, handling errors - Ok::<_, ZKVMError>((WitIn::from_expr($name, $builder, $e, $debug)?, $(WitIn::from_expr($name, $builder, $rest)?),*)) - } - }; -} - pub trait ToExpr { type Output; fn expr(&self) -> Self::Output; @@ -638,6 +658,20 @@ impl ToExpr for &WitIn { } } +impl ToExpr for Fixed { + type Output = Expression; + fn expr(&self) -> Expression { + Expression::Fixed(*self) + } +} + +impl ToExpr for &Fixed { + type Output = Expression; + fn expr(&self) -> Expression { + Expression::Fixed(**self) + } +} + impl ToExpr for Instance { type Output = Expression; fn expr(&self) -> Expression { @@ -703,11 +737,13 @@ pub mod fmt { pub fn expr( expression: &Expression, wtns: &mut Vec, - add_prn_sum: bool, + add_parens_sum: bool, ) -> String { 
match expression { Expression::WitIn(wit_in) => { - wtns.push(*wit_in); + if !wtns.contains(wit_in) { + wtns.push(*wit_in); + } format!("WitIn({})", wit_in) } Expression::Challenge(id, pow, scaler, offset) => { @@ -728,12 +764,18 @@ pub mod fmt { s } } - Expression::Constant(constant) => base_field::(constant, true).to_string(), + Expression::Constant(constant) => { + base_field::(constant, true).to_string() + } Expression::Fixed(fixed) => format!("{:?}", fixed), Expression::Instance(i) => format!("{:?}", i), Expression::Sum(left, right) => { let s = format!("{} + {}", expr(left, wtns, false), expr(right, wtns, false)); - if add_prn_sum { format!("({})", s) } else { s } + if add_parens_sum { + format!("({})", s) + } else { + s + } } Expression::Product(left, right) => { format!("{} * {}", expr(left, wtns, true), expr(right, wtns, true)) @@ -745,7 +787,11 @@ pub mod fmt { expr(x, wtns, true), expr(b, wtns, false) ); - if add_prn_sum { format!("({})", s) } else { s } + if add_parens_sum { + format!("({})", s) + } else { + s + } } } } @@ -757,7 +803,7 @@ pub mod fmt { let data = field .as_bases() .iter() - .map(|b| base_field::(b, false)) + .map(|b| base_field::(b, false)) .collect::>(); let only_one_limb = field.as_bases()[1..].iter().all(|&x| x == 0.into()); @@ -768,29 +814,26 @@ pub mod fmt { } } - pub fn base_field(base_field: &E::BaseField, add_prn: bool) -> String { + pub fn base_field(base_field: &F, add_parens: bool) -> String { let value = base_field.to_canonical_u64(); - if value > E::BaseField::MODULUS_U64 - u16::MAX as u64 { + if value > F::MODULUS_U64 - u16::MAX as u64 { // beautiful format for negative number > -65536 - prn(format!("-{}", E::BaseField::MODULUS_U64 - value), add_prn) + parens(format!("-{}", F::MODULUS_U64 - value), add_parens) } else if value < u16::MAX as u64 { format!("{value}") } else { // hex - if value > E::BaseField::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { - prn( - format!("-{:#x}", E::BaseField::MODULUS_U64 - value), - 
add_prn, - ) + if value > F::MODULUS_U64 - (u32::MAX as u64 + u16::MAX as u64) { + parens(format!("-{:#x}", F::MODULUS_U64 - value), add_parens) } else { format!("{value:#x}") } } } - pub fn prn(s: String, add_prn: bool) -> String { - if add_prn { format!("({})", s) } else { s } + pub fn parens(s: String, add_parens: bool) -> String { + if add_parens { format!("({})", s) } else { s } } #[cfg(test)] @@ -801,18 +844,17 @@ pub mod fmt { wits_in_name: &[String], ) -> String { use itertools::Itertools; + use multilinear_extensions::mle::FieldType; wtns.iter() .sorted() .map(|wt_id| { let wit = &wits_in[*wt_id as usize]; let name = &wits_in_name[*wt_id as usize]; - let value_fmt = if let Some(e) = wit.get_ext_field_vec_optn() { - field(&e[inst_id]) - } else if let Some(bf) = wit.get_base_field_vec_optn() { - base_field::(&bf[inst_id], true) - } else { - "Unknown".to_string() + let value_fmt = match wit.evaluations() { + FieldType::Base(vec) => base_field::(&vec[inst_id], true), + FieldType::Ext(vec) => field(&vec[inst_id]), + FieldType::Unreachable => unreachable!(), }; format!(" WitIn({wt_id})={value_fmt} {name:?}") }) diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs index 8b159c950..60814a6fd 100644 --- a/ceno_zkvm/src/expression/monomial.rs +++ b/ceno_zkvm/src/expression/monomial.rs @@ -1,6 +1,4 @@ use ff_ext::ExtensionField; -use goldilocks::SmallField; -use std::cmp::Ordering; use super::Expression; use Expression::*; @@ -94,43 +92,6 @@ struct Term { vars: Vec>, } -// Define a lexicographic order for expressions. It compares the types first, then the arguments left-to-right. 
-impl Ord for Expression { - fn cmp(&self, other: &Self) -> Ordering { - use Ordering::*; - - match (self, other) { - (Fixed(a), Fixed(b)) => a.cmp(b), - (WitIn(a), WitIn(b)) => a.cmp(b), - (Instance(a), Instance(b)) => a.cmp(b), - (Challenge(a, b, c, d), Challenge(e, f, g, h)) => { - (a, b, for_cmp_ext(c), for_cmp_ext(d)).cmp(&(e, f, for_cmp_ext(g), for_cmp_ext(h))) - } - (Fixed(_), _) => Less, - (Instance(_), Fixed(_)) => Greater, - (Instance(_), _) => Less, - (WitIn(_), Fixed(_)) => Greater, - (WitIn(_), Instance(_)) => Greater, - (WitIn(_), _) => Less, - (Challenge(..), Fixed(_)) => Greater, - (Challenge(..), Instance(_)) => Greater, - (Challenge(..), WitIn(_)) => Greater, - // TODO(Matthias): remove this, and check whether we got all the cases. - _ => unreachable!(), - } - } -} - -impl PartialOrd for Expression { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -fn for_cmp_ext(a: &E) -> Vec { - a.as_bases().iter().map(|f| f.to_canonical_u64()).collect() -} - #[cfg(test)] mod tests { use crate::{expression::Fixed as FixedS, scheme::utils::eval_by_expr_with_fixed}; diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index 083d9438e..f8d40cdee 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -1,15 +1,17 @@ use std::{fmt::Display, mem::MaybeUninit}; -use ceno_emul::SWord; +use ceno_emul::{SWord, Word}; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::izip; use crate::{ + Value, chip_handler::utils::power_sequence, circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, + instructions::riscv::constants::{UINT_LIMBS, UInt}, set_val, witness::LkMultiplicity, }; @@ -53,7 +55,7 @@ impl AssertLTConfig { lhs: u64, rhs: u64, ) -> Result<(), ZKVMError> { - self.0.assign_instance(instance, lkm, true, lhs, rhs)?; + self.0.assign_instance(instance, lkm, lhs, rhs)?; Ok(()) } } @@ -107,10 +109,21 @@ impl IsLtConfig { lhs: u64, 
rhs: u64, ) -> Result<(), ZKVMError> { - let is_lt = lhs < rhs; - set_val!(instance, self.is_lt, is_lt as u64); + set_val!(instance, self.is_lt, (lhs < rhs) as u64); + self.config.assign_instance(instance, lkm, lhs, rhs)?; + Ok(()) + } + + pub fn assign_instance_signed( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: SWord, + rhs: SWord, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.is_lt, (lhs < rhs) as u64); self.config - .assign_instance(instance, lkm, is_lt, lhs, rhs)?; + .assign_instance_signed(instance, lkm, lhs, rhs)?; Ok(()) } } @@ -157,7 +170,7 @@ impl InnerLtConfig { .map(|(record, beta)| beta * record.expr()) .sum::>(); - let range = Self::range(max_num_u16_limbs).into(); + let range = Self::range(max_num_u16_limbs); cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?; @@ -171,11 +184,10 @@ impl InnerLtConfig { &self, instance: &mut [MaybeUninit], lkm: &mut LkMultiplicity, - is_lt: bool, lhs: u64, rhs: u64, ) -> Result<(), ZKVMError> { - let diff = cal_lt_diff(is_lt, self.max_num_u16_limbs, lhs, rhs); + let diff = cal_lt_diff(lhs < rhs, self.max_num_u16_limbs, lhs, rhs); self.diff.iter().enumerate().for_each(|(i, wit)| { // extract the 16 bit limb from diff and assign to instance let val = (diff >> (i * u16::BITS as usize)) & 0xffff; @@ -190,14 +202,13 @@ impl InnerLtConfig { &self, instance: &mut [MaybeUninit], lkm: &mut LkMultiplicity, - is_signed_lt: bool, lhs: SWord, rhs: SWord, ) -> Result<(), ZKVMError> { - let diff = if is_signed_lt { - Self::range(self.diff.len()) - (rhs - lhs) as u64 + let diff = if lhs < rhs { + Self::range(self.diff.len()) - lhs.abs_diff(rhs) as u64 } else { - (lhs - rhs) as u64 + lhs.abs_diff(rhs) as u64 }; self.diff.iter().enumerate().for_each(|(i, wit)| { // extract the 16 bit limb from diff and assign to instance @@ -217,3 +228,171 @@ pub fn cal_lt_diff(is_lt: bool, max_num_u16_limbs: usize, lhs: u64, rhs: u64) -> } + lhs - rhs) } + +#[derive(Debug)] +pub 
struct AssertSignedLtConfig { + config: InnerSignedLtConfig, +} + +impl AssertSignedLtConfig { + pub fn construct_circuit< + E: ExtensionField, + NR: Into + Display + Clone, + N: FnOnce() -> NR, + >( + cb: &mut CircuitBuilder, + name_fn: N, + lhs: &UInt, + rhs: &UInt, + ) -> Result { + cb.namespace( + || "assert_signed_lt", + |cb| { + let name = name_fn(); + let config = + InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, Expression::ONE)?; + Ok(Self { config }) + }, + ) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: SWord, + rhs: SWord, + ) -> Result<(), ZKVMError> { + self.config.assign_instance::(instance, lkm, lhs, rhs)?; + Ok(()) + } +} + +#[derive(Debug)] +pub struct SignedLtConfig { + is_lt: WitIn, + config: InnerSignedLtConfig, +} + +impl SignedLtConfig { + pub fn expr(&self) -> Expression { + self.is_lt.expr() + } + + pub fn construct_circuit< + E: ExtensionField, + NR: Into + Display + Clone, + N: FnOnce() -> NR, + >( + cb: &mut CircuitBuilder, + name_fn: N, + lhs: &UInt, + rhs: &UInt, + ) -> Result { + cb.namespace( + || "is_signed_lt", + |cb| { + let name = name_fn(); + let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"))?; + cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; + let config = + InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; + + Ok(SignedLtConfig { is_lt, config }) + }, + ) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: SWord, + rhs: SWord, + ) -> Result<(), ZKVMError> { + set_val!(instance, self.is_lt, (lhs < rhs) as u64); + self.config + .assign_instance::(instance, lkm, lhs as SWord, rhs as SWord)?; + Ok(()) + } +} + +#[derive(Debug)] +struct InnerSignedLtConfig { + is_lhs_neg: IsLtConfig, + is_rhs_neg: IsLtConfig, + config: InnerLtConfig, +} + +impl InnerSignedLtConfig { + pub fn construct_circuit + Display + Clone>( + cb: &mut CircuitBuilder, + name: NR, + lhs: &UInt, 
+ rhs: &UInt, + is_lt_expr: Expression, + ) -> Result { + let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + // Extract the sign bit. + let is_lhs_neg = IsLtConfig::construct_circuit( + cb, + || "lhs_msb", + max_signed_limb_expr.clone(), + lhs.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + let is_rhs_neg = IsLtConfig::construct_circuit( + cb, + || "rhs_msb", + max_signed_limb_expr, + rhs.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + + // Convert to field arithmetic. + let lhs_value = lhs.to_field_expr(is_lhs_neg.expr()); + let rhs_value = rhs.to_field_expr(is_rhs_neg.expr()); + let config = InnerLtConfig::construct_circuit( + cb, + format!("{name} (lhs < rhs)"), + lhs_value, + rhs_value, + is_lt_expr, + UINT_LIMBS, + )?; + + Ok(Self { + is_lhs_neg, + is_rhs_neg, + config, + }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + lhs: SWord, + rhs: SWord, + ) -> Result<(), ZKVMError> { + let max_signed_limb = (1u64 << (UInt::::LIMB_BITS - 1)) - 1; + let lhs_value = Value::new_unchecked(lhs as Word); + let rhs_value = Value::new_unchecked(rhs as Word); + self.is_lhs_neg.assign_instance( + instance, + lkm, + max_signed_limb, + *lhs_value.limbs.last().unwrap() as u64, + )?; + self.is_rhs_neg.assign_instance( + instance, + lkm, + max_signed_limb, + *rhs_value.limbs.last().unwrap() as u64, + )?; + + self.config + .assign_instance_signed(instance, lkm, lhs, rhs)?; + Ok(()) + } +} diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 377c254ef..60846581e 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,6 +1,11 @@ mod div; mod is_lt; mod is_zero; +mod signed_ext; + pub use div::DivConfig; -pub use is_lt::{AssertLTConfig, InnerLtConfig, IsLtConfig, cal_lt_diff}; +pub use is_lt::{ + AssertLTConfig, AssertSignedLtConfig, InnerLtConfig, IsLtConfig, SignedLtConfig, cal_lt_diff, +}; pub use 
is_zero::{IsEqualConfig, IsZeroConfig}; +pub use signed_ext::SignedExtendConfig; diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs new file mode 100644 index 000000000..b81c009f7 --- /dev/null +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -0,0 +1,93 @@ +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::constants::UInt, + set_val, + witness::LkMultiplicity, +}; +use ff_ext::ExtensionField; +use std::mem::MaybeUninit; + +pub struct SignedExtendConfig { + /// most significant bit + msb: WitIn, + /// number of bits contained in the value + n_bits: usize, +} + +impl SignedExtendConfig { + pub fn construct_limb( + cb: &mut CircuitBuilder, + val: Expression, + ) -> Result { + Self::construct_circuit(cb, 16, val) + } + + pub fn construct_byte( + cb: &mut CircuitBuilder, + val: Expression, + ) -> Result { + Self::construct_circuit(cb, 8, val) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + n_bits: usize, + val: Expression, // it's assumed that val is within [0, 2^N_BITS) + ) -> Result { + assert!(n_bits == 8 || n_bits == 16); + + let msb = cb.create_witin(|| "msb")?; + // require msb is boolean + cb.assert_bit(|| "msb is boolean", msb.expr())?; + + // assert 2*val - msb*2^N_BITS is within range [0, 2^N_BITS) + // - if val < 2^(N_BITS-1), then 2*val < 2^N_BITS, msb can only be zero. + // - otherwise, 2*val >= 2^N_BITS, then msb can only be one. 
+ let assert_ux = match n_bits { + 8 => CircuitBuilder::::assert_ux::<_, _, 8>, + 16 => CircuitBuilder::::assert_ux::<_, _, 16>, + _ => unreachable!("unsupported n_bits = {}", n_bits), + }; + assert_ux( + cb, + || "0 <= 2*val - msb*2^N_BITS < 2^N_BITS", + 2 * val - msb.expr() * (1 << n_bits), + )?; + + Ok(SignedExtendConfig { msb, n_bits }) + } + + /// Get the signed extended value + pub fn signed_extended_value(&self, val: Expression) -> UInt { + assert_eq!(UInt::::LIMB_BITS, 16); + + let limb0 = match self.n_bits { + 8 => self.msb.expr() * 0xff00 + val, + 16 => val, + _ => unreachable!("unsupported N_BITS = {}", self.n_bits), + }; + UInt::from_exprs_unchecked(vec![limb0, self.msb.expr() * 0xffff]) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + val: u64, + ) -> Result<(), ZKVMError> { + let msb = val >> (self.n_bits - 1); + + let assert_ux = match self.n_bits { + 8 => LkMultiplicity::assert_ux::<8>, + 16 => LkMultiplicity::assert_ux::<16>, + _ => unreachable!("unsupported n_bits = {}", self.n_bits), + }; + + assert_ux(lk_multiplicity, 2 * val - (msb << self.n_bits)); + set_val!(instance, self.msb, E::BaseField::from(msb)); + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 8a93401c4..308ea27fb 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -53,7 +53,8 @@ pub trait Instruction { steps.len().div_ceil(nthreads) } else { steps.len() - }; + } + .max(1); let lk_multiplicity = LkMultiplicity::default(); let mut raw_witin = RowMajorMatrix::::new(steps.len(), num_witin); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index a87aaa909..96b192d60 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -1,5 +1,8 @@ use ceno_emul::InsnKind; +mod rv32im; +pub use 
rv32im::Rv32imConfig; + pub mod arith; pub mod arith_imm; pub mod branch; @@ -9,9 +12,12 @@ pub mod divu; pub mod ecall; pub mod jump; pub mod logic; +pub mod logic_imm; pub mod mulh; pub mod shift; pub mod shift_imm; +pub mod slt; +pub mod slti; pub mod sltu; mod b_insn; @@ -24,10 +30,12 @@ mod u_insn; mod ecall_insn; mod im_insn; -mod mem; +mod memory; mod s_insn; #[cfg(test)] mod test; +#[cfg(test)] +mod test_utils; pub trait RIVInstruction { const INST_KIND: InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 0f12b9ffe..1d0d2146a 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -169,7 +169,7 @@ impl Instruction for ArithInstruction Instruction for AddiInstruction { circuit_builder: &mut CircuitBuilder, ) -> Result { let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let imm = UInt::new_unchecked(|| "imm", circuit_builder)?; + let imm = UInt::new(|| "imm", circuit_builder)?; let rd_written = rs1_read.add(|| "rs1_read + imm", circuit_builder, &imm, true)?; let i_insn = IInstructionConfig::::construct_circuit( @@ -44,6 +44,7 @@ impl Instruction for AddiInstruction { &imm.value(), rs1_read.register_expr(), rd_written.register_expr(), + false, )?; Ok(InstructionConfig { @@ -80,15 +81,15 @@ impl Instruction for AddiInstruction { #[cfg(test)] mod test { - use ceno_emul::{Change, StepRecord}; + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, - scheme::mock_prover::{MOCK_PC_ADDI, MOCK_PC_ADDI_SUB, MOCK_PROGRAM, MockProver}, + instructions::{Instruction, riscv::test_utils::imm_i}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, }; use super::AddiInstruction; @@ -108,13 +109,14 @@ mod test { .unwrap() .unwrap(); 
- let (raw_witin, _) = AddiInstruction::::assign_instances( + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(3)); + let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_i_instruction( 3, - MOCK_PC_ADDI, - MOCK_PROGRAM[13], + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, 1000, Change::new(0, 1003), 0, @@ -130,7 +132,9 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } @@ -149,13 +153,14 @@ mod test { .unwrap() .unwrap(); - let (raw_witin, _) = AddiInstruction::::assign_instances( + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(-3)); + let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_i_instruction( 3, - MOCK_PC_ADDI_SUB, - MOCK_PROGRAM[14], + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, 1000, Change::new(0, 997), 0, @@ -171,7 +176,9 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), - Some([1.into(), 10000.into()]), + &[insn_code], + None, + Some(lkm), ); } } diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 78536ecee..b7c74543f 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -71,9 +71,8 @@ impl BInstructionConfig { ))?; // Branch program counter - let pc_offset = branch_taken_bit.clone() * imm.expr() - - branch_taken_bit * PC_STEP_SIZE.into() - + PC_STEP_SIZE.into(); + let pc_offset = + branch_taken_bit.clone() * imm.expr() - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; let next_pc = vm_state.next_pc.unwrap(); circuit_builder.require_equal( || "pc_branch", diff --git a/ceno_zkvm/src/instructions/riscv/branch/blt.rs b/ceno_zkvm/src/instructions/riscv/branch/blt.rs index cd3040800..43f07bbfb 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/blt.rs +++ 
b/ceno_zkvm/src/instructions/riscv/branch/blt.rs @@ -7,11 +7,10 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, + gadgets::SignedLtConfig, instructions::{ Instruction, - riscv::{ - RIVInstruction, b_insn::BInstructionConfig, config::SignedLtConfig, constants::UInt, - }, + riscv::{RIVInstruction, b_insn::BInstructionConfig, constants::UInt}, }, witness::LkMultiplicity, }; diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 25fd8e7bd..36746fff3 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -1,4 +1,4 @@ -use ceno_emul::{ByteAddr, Change, PC_STEP_SIZE, StepRecord, Word}; +use ceno_emul::{ByteAddr, Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; @@ -7,11 +7,8 @@ use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - instructions::Instruction, - scheme::mock_prover::{ - MOCK_PC_BEQ, MOCK_PC_BGE, MOCK_PC_BGEU, MOCK_PC_BLT, MOCK_PC_BLTU, MOCK_PC_BNE, - MOCK_PROGRAM, MockProver, - }, + instructions::{Instruction, riscv::test_utils::imm_b}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, }; const A: Word = 0xbead1010; @@ -37,13 +34,14 @@ fn impl_opcode_beq(equal: bool) { .unwrap() .unwrap(); + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, imm_b(8)); let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; - let (raw_witin, _lkm) = + let (raw_witin, lkm) = BeqInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( 3, - Change::new(MOCK_PC_BEQ, MOCK_PC_BEQ + pc_offset), - MOCK_PROGRAM[6], + Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset), + insn_code, A, if equal { A } else { B }, 0, @@ -59,7 +57,9 @@ fn impl_opcode_beq(equal: bool) { .into_iter() .map(|v| v.into()) .collect_vec(), + 
&[insn_code], None, + Some(lkm), ); } @@ -83,13 +83,14 @@ fn impl_opcode_bne(equal: bool) { .unwrap() .unwrap(); + let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, imm_b(8)); let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; - let (raw_witin, _lkm) = + let (raw_witin, lkm) = BneInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( 3, - Change::new(MOCK_PC_BNE, MOCK_PC_BNE + pc_offset), - MOCK_PROGRAM[7], + Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset), + insn_code, A, if equal { A } else { B }, 0, @@ -105,7 +106,9 @@ fn impl_opcode_bne(equal: bool) { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } @@ -127,17 +130,19 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let config = BltuInstruction::construct_circuit(&mut circuit_builder)?; let pc_after = if taken { - ByteAddr(MOCK_PC_BLTU.0 - 8) + ByteAddr(MOCK_PC_START.0 - 8) } else { - MOCK_PC_BLTU + PC_STEP_SIZE + MOCK_PC_START + PC_STEP_SIZE }; - let (raw_witin, _) = + let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, imm_b(-8)); + println!("{:#b}", insn_code); + let (raw_witin, lkm) = BltuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( 12, - Change::new(MOCK_PC_BLTU, pc_after), - MOCK_PROGRAM[15], + Change::new(MOCK_PC_START, pc_after), + insn_code, a as Word, b as Word, 10, @@ -153,7 +158,9 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); Ok(()) } @@ -176,17 +183,18 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let config = BgeuInstruction::construct_circuit(&mut circuit_builder)?; let pc_after = if taken { - ByteAddr(MOCK_PC_BGEU.0 - 8) + ByteAddr(MOCK_PC_START.0 - 8) } else { - MOCK_PC_BGEU + PC_STEP_SIZE + MOCK_PC_START + PC_STEP_SIZE }; - let (raw_witin, _) = + 
let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, imm_b(-8)); + let (raw_witin, lkm) = BgeuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( 12, - Change::new(MOCK_PC_BGEU, pc_after), - MOCK_PROGRAM[16], + Change::new(MOCK_PC_START, pc_after), + insn_code, a as Word, b as Word, 10, @@ -202,7 +210,9 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); Ok(()) } @@ -226,17 +236,18 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { let config = BltInstruction::construct_circuit(&mut circuit_builder)?; let pc_after = if taken { - ByteAddr(MOCK_PC_BLT.0 - 8) + ByteAddr(MOCK_PC_START.0 - 8) } else { - MOCK_PC_BLT + PC_STEP_SIZE + MOCK_PC_START + PC_STEP_SIZE }; - let (raw_witin, _) = + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, imm_b(-8)); + let (raw_witin, lkm) = BltInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( 12, - Change::new(MOCK_PC_BLT, pc_after), - MOCK_PROGRAM[8], + Change::new(MOCK_PC_START, pc_after), + insn_code, a as Word, b as Word, 10, @@ -252,7 +263,9 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); Ok(()) } @@ -276,17 +289,18 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { let config = BgeInstruction::construct_circuit(&mut circuit_builder)?; let pc_after = if taken { - ByteAddr(MOCK_PC_BGE.0 - 8) + ByteAddr(MOCK_PC_START.0 - 8) } else { - MOCK_PC_BGE + PC_STEP_SIZE + MOCK_PC_START + PC_STEP_SIZE }; - let (raw_witin, _) = + let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, imm_b(-8)); + let (raw_witin, lkm) = BgeInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( 12, 
- Change::new(MOCK_PC_BGE, pc_after), - MOCK_PROGRAM[17], + Change::new(MOCK_PC_START, pc_after), + insn_code, a as Word, b as Word, 10, @@ -302,7 +316,9 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/config.rs b/ceno_zkvm/src/instructions/riscv/config.rs index 0449f386b..cd13e1148 100644 --- a/ceno_zkvm/src/instructions/riscv/config.rs +++ b/ceno_zkvm/src/instructions/riscv/config.rs @@ -1,22 +1,9 @@ -use std::{fmt::Display, mem::MaybeUninit}; +use std::mem::MaybeUninit; -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{Expression, ToExpr, WitIn}, - gadgets::{InnerLtConfig, IsLtConfig}, - set_val, - utils::i64_to_base, - witness::LkMultiplicity, -}; -use ceno_emul::{SWord, Word}; -use ff_ext::ExtensionField; +use crate::{expression::WitIn, set_val, utils::i64_to_base, witness::LkMultiplicity}; use goldilocks::SmallField; use itertools::Itertools; -use super::constants::{UINT_LIMBS, UInt}; - #[derive(Clone)] pub struct IsEqualConfig { pub is_equal_per_limb: Vec, @@ -201,180 +188,3 @@ impl UIntLtInput<'_> { is_lt > 0 } } - -#[allow(dead_code)] -#[derive(Debug)] -pub struct AssertSignedLtConfig { - config: InnerSignedLtConfig, -} - -impl AssertSignedLtConfig { - pub fn construct_circuit< - E: ExtensionField, - NR: Into + Display + Clone, - N: FnOnce() -> NR, - >( - cb: &mut CircuitBuilder, - name_fn: N, - lhs: &UInt, - rhs: &UInt, - ) -> Result { - cb.namespace( - || "assert_signed_lt", - |cb| { - let name = name_fn(); - let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, Expression::ONE)?; - Ok(Self { config }) - }, - ) - } - - pub fn assign_instance( - &self, - instance: &mut [MaybeUninit], - lkm: &mut LkMultiplicity, - lhs: SWord, - rhs: SWord, - ) -> Result<(), ZKVMError> { - let is_lt = lhs < rhs; - self.config - 
.assign_instance::(instance, lkm, is_lt, lhs, rhs)?; - Ok(()) - } -} - -#[allow(dead_code)] -#[derive(Debug)] -pub struct SignedLtConfig { - is_lt: WitIn, - config: InnerSignedLtConfig, -} - -impl SignedLtConfig { - pub fn expr(&self) -> Expression { - self.is_lt.expr() - } - - pub fn construct_circuit< - E: ExtensionField, - NR: Into + Display + Clone, - N: FnOnce() -> NR, - >( - cb: &mut CircuitBuilder, - name_fn: N, - lhs: &UInt, - rhs: &UInt, - ) -> Result { - cb.namespace( - || "is_signed_lt", - |cb| { - let name = name_fn(); - let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin"))?; - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; - let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; - - Ok(SignedLtConfig { is_lt, config }) - }, - ) - } - - pub fn assign_instance( - &self, - instance: &mut [MaybeUninit], - lkm: &mut LkMultiplicity, - lhs: SWord, - rhs: SWord, - ) -> Result<(), ZKVMError> { - let is_lt = lhs < rhs; - set_val!(instance, self.is_lt, is_lt as u64); - self.config - .assign_instance::(instance, lkm, is_lt, lhs as SWord, rhs as SWord)?; - Ok(()) - } -} - -#[derive(Debug)] -struct InnerSignedLtConfig { - is_lhs_neg: IsLtConfig, - is_rhs_neg: IsLtConfig, - config: InnerLtConfig, -} - -impl InnerSignedLtConfig { - pub fn construct_circuit + Display + Clone>( - cb: &mut CircuitBuilder, - name: NR, - lhs: &UInt, - rhs: &UInt, - is_lt_expr: Expression, - ) -> Result { - let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::C - 1)) - 1).into(); - // Extract the sign bit. - let is_lhs_neg = IsLtConfig::construct_circuit( - cb, - || "lhs_msb", - max_signed_limb_expr.clone(), - lhs.limbs.iter().last().unwrap().expr(), // msb limb - 1, - )?; - let is_rhs_neg = IsLtConfig::construct_circuit( - cb, - || "rhs_msb", - max_signed_limb_expr, - rhs.limbs.iter().last().unwrap().expr(), // msb limb - 1, - )?; - - // Convert two's complement representation into field arithmetic. 
- // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 - let neg_shift = -Expression::Constant((1_u64 << 32).into()); - let lhs_value = lhs.value() + is_lhs_neg.expr() * neg_shift.clone(); - let rhs_value = rhs.value() + is_rhs_neg.expr() * neg_shift; - - let config = InnerLtConfig::construct_circuit( - cb, - format!("{name}lhs( - &self, - instance: &mut [MaybeUninit], - lkm: &mut LkMultiplicity, - is_signed_lt: bool, - lhs: SWord, - rhs: SWord, - ) -> Result<(), ZKVMError> { - let max_signed_limb = (1u64 << (UInt::::C - 1)) - 1; - let lhs_value = Value::new_unchecked(lhs as Word); - let rhs_value = Value::new_unchecked(rhs as Word); - self.is_lhs_neg.assign_instance( - instance, - lkm, - max_signed_limb, - *lhs_value.limbs.last().unwrap() as u64, - )?; - self.is_rhs_neg.assign_instance( - instance, - lkm, - max_signed_limb, - *rhs_value.limbs.last().unwrap() as u64, - )?; - - self.config - .assign_instance_signed(instance, lkm, is_signed_lt, lhs, rhs)?; - Ok(()) - } -} diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 479e0d33b..abf5330de 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -4,14 +4,21 @@ pub use ceno_emul::PC_STEP_SIZE; pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; pub const EXIT_PC: usize = 0; pub const EXIT_CODE_IDX: usize = 0; -pub const VALUE_BIT_WIDTH: usize = 16; + +pub const INIT_PC_IDX: usize = 2; +pub const INIT_CYCLE_IDX: usize = 3; +pub const END_PC_IDX: usize = 4; +pub const END_CYCLE_IDX: usize = 5; + +pub const LIMB_BITS: usize = 16; +pub const LIMB_MASK: u32 = 0xFFFF; #[cfg(feature = "riv32")] pub const BIT_WIDTH: usize = 32usize; #[cfg(feature = "riv64")] pub const BIT_WIDTH: usize = 64usize; -pub type UInt = UIntLimbs; -pub type UIntMul = UIntLimbs<{ 2 * BIT_WIDTH }, VALUE_BIT_WIDTH, E>; +pub type UInt = UIntLimbs; +pub type UIntMul = UIntLimbs<{ 2 * BIT_WIDTH }, LIMB_BITS, E>; /// use UInt 
for x bits limb size pub type UInt8 = UIntLimbs; -pub const UINT_LIMBS: usize = BIT_WIDTH.div_ceil(VALUE_BIT_WIDTH); +pub const UINT_LIMBS: usize = BIT_WIDTH.div_ceil(LIMB_BITS); diff --git a/ceno_zkvm/src/instructions/riscv/divu.rs b/ceno_zkvm/src/instructions/riscv/divu.rs index 9150389fb..c353a6caf 100644 --- a/ceno_zkvm/src/instructions/riscv/divu.rs +++ b/ceno_zkvm/src/instructions/riscv/divu.rs @@ -62,7 +62,7 @@ impl Instruction for ArithInstruction::M) - 1).into(), + ((1u64 << UInt::::TOTAL_BITS) - 1).into(), outcome_value, )?; @@ -117,11 +117,7 @@ impl Instruction for ArithInstruction::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb @@ -187,13 +189,15 @@ mod test { } else { dividend / divisor }; + + let insn_code = encode_rv32(InsnKind::DIVU, 2, 3, 4, 0); // values assignment - let (raw_witin, _) = + let (raw_witin, lkm) = DivUInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ StepRecord::new_r_instruction( 3, - MOCK_PC_DIVU, - MOCK_PROGRAM[9], + MOCK_PC_START, + insn_code, dividend, divisor, Change::new(0, outcome), @@ -211,7 +215,8 @@ mod test { .require_equal(|| "assert_outcome", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied( + let expected_errors: &[_] = if is_ok { &[] } else { &[name] }; + MockProver::assert_with_expected_errors( &cb, &raw_witin .de_interleaving() @@ -219,19 +224,27 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], + expected_errors, None, + Some(lkm), ); } #[test] fn test_opcode_divu() { - verify("basic", 10, 2, 5); - verify("dividend > divisor", 10, 11, 0); - verify("remainder", 11, 2, 5); - verify("u32::MAX", u32::MAX, u32::MAX, 1); - verify("div u32::MAX", 3, u32::MAX, 0); - verify("u32::MAX div by 2", u32::MAX, 2, u32::MAX / 2); - verify("mul with carries", 1202729773, 171818539, 7); - verify("div by zero", 10, 0, u32::MAX); + verify("basic", 10, 2, 5, true); + verify("dividend > divisor", 10, 11, 0, true); + verify("remainder", 
11, 2, 5, true); + verify("u32::MAX", u32::MAX, u32::MAX, 1, true); + verify("div u32::MAX", 3, u32::MAX, 0, true); + verify("u32::MAX div by 2", u32::MAX, 2, u32::MAX / 2, true); + verify("mul with carries", 1202729773, 171818539, 7, true); + verify("div by zero", 10, 0, u32::MAX, true); + } + + #[test] + fn test_opcode_divu_unstatisfied() { + verify("assert_outcome", 10, 2, 3, false); } #[test] @@ -239,7 +252,7 @@ mod test { let mut rng = rand::thread_rng(); let a: u32 = rng.gen(); let b: u32 = rng.gen_range(1..u32::MAX); - verify("random", a, b, a / b); + verify("random", a, b, a / b, true); } } } diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index d13117282..21ba4de0c 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -52,7 +52,7 @@ impl Instruction for HaltInstruction { || "read x10", E::BaseField::from(ceno_emul::CENO_PLATFORM.reg_arg0() as u64), prev_x10_ts.expr(), - ecall_cfg.ts.expr() + (Tracer::SUBCYCLE_RS2 as usize).into(), + ecall_cfg.ts.expr() + Tracer::SUBCYCLE_RS2, exit_code, )?; diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 115042347..49bc1d67a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -33,8 +33,8 @@ impl EcallInstructionConfig { cb.state_in(pc.expr(), ts.expr())?; cb.state_out( - next_pc.map_or(pc.expr() + PC_STEP_SIZE.into(), |next_pc| next_pc), - ts.expr() + (Tracer::SUBCYCLES_PER_INSN as usize).into(), + next_pc.map_or(pc.expr() + PC_STEP_SIZE, |next_pc| next_pc), + ts.expr() + (Tracer::SUBCYCLES_PER_INSN as usize), )?; cb.lk_fetch(&InsnRecord::new( @@ -54,7 +54,7 @@ impl EcallInstructionConfig { || "write x5", E::BaseField::from(CENO_PLATFORM.reg_ecall() as u64), prev_x5_ts.expr(), - ts.expr() + (Tracer::SUBCYCLE_RS1 as usize).into(), + ts.expr() + Tracer::SUBCYCLE_RS1, 
syscall_id.clone(), syscall_ret_value.map_or(syscall_id, |v| v), )?; diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index 2b140aedd..4a099477b 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -31,9 +31,10 @@ impl IInstructionConfig { imm: &Expression, rs1_read: RegisterExpr, rd_written: RegisterExpr, + branching: bool, ) -> Result { // State in and out - let vm_state = StateInOut::construct_circuit(circuit_builder, false)?; + let vm_state = StateInOut::construct_circuit(circuit_builder, branching)?; // Registers let rs1 = ReadRS1::construct_circuit(circuit_builder, rs1_read, vm_state.ts)?; diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 371ac0c3d..0977dfa55 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -1,5 +1,5 @@ use crate::{ - chip_handler::{MemoryExpr, RegisterExpr}, + chip_handler::{AddressExpr, MemoryExpr, RegisterExpr}, circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr}, @@ -29,7 +29,7 @@ impl IMInstructionConfig { imm: &Expression, rs1_read: RegisterExpr, memory_read: MemoryExpr, - memory_addr: MemoryExpr, + memory_addr: AddressExpr, rd_written: RegisterExpr, ) -> Result { let vm_state = StateInOut::construct_circuit(circuit_builder, false)?; diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 65ee05c1f..5548786f5 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -1,10 +1,12 @@ -use ceno_emul::StepRecord; +use ceno_emul::{StepRecord, Word}; +use ff::Field; use ff_ext::ExtensionField; +use itertools::Itertools; use super::constants::{PC_STEP_SIZE, UINT_LIMBS, UInt}; use crate::{ chip_handler::{ - GlobalStateRegisterMachineChipOperations, MemoryChipOperations, 
MemoryExpr, + AddressExpr, GlobalStateRegisterMachineChipOperations, MemoryChipOperations, MemoryExpr, RegisterChipOperations, RegisterExpr, }, circuit_builder::CircuitBuilder, @@ -17,7 +19,7 @@ use crate::{ }; use ceno_emul::Tracer; use core::mem::MaybeUninit; -use std::marker::PhantomData; +use std::{iter, marker::PhantomData}; #[derive(Debug)] pub struct StateInOut { @@ -40,10 +42,10 @@ impl StateInOut { let next_pc = circuit_builder.create_witin(|| "next_pc")?; (Some(next_pc), next_pc.expr()) } else { - (None, pc.expr() + PC_STEP_SIZE.into()) + (None, pc.expr() + PC_STEP_SIZE) }; let ts = circuit_builder.create_witin(|| "ts")?; - let next_ts = ts.expr() + (Tracer::SUBCYCLES_PER_INSN as usize).into(); + let next_ts = ts.expr() + Tracer::SUBCYCLES_PER_INSN; circuit_builder.state_in(pc.expr(), ts.expr())?; circuit_builder.state_out(next_pc_expr, next_ts)?; @@ -91,7 +93,7 @@ impl ReadRS1 { || "read_rs1", id, prev_ts.expr(), - cur_ts.expr() + (Tracer::SUBCYCLE_RS1 as usize).into(), + cur_ts.expr() + Tracer::SUBCYCLE_RS1, rs1_read, )?; @@ -146,7 +148,7 @@ impl ReadRS2 { || "read_rs2", id, prev_ts.expr(), - cur_ts.expr() + (Tracer::SUBCYCLE_RS2 as usize).into(), + cur_ts.expr() + Tracer::SUBCYCLE_RS2, rs2_read, )?; @@ -202,7 +204,7 @@ impl WriteRD { || "write_rd", id, prev_ts.expr(), - cur_ts.expr() + (Tracer::SUBCYCLE_RD as usize).into(), + cur_ts.expr() + Tracer::SUBCYCLE_RD, prev_value.register_expr(), rd_written, )?; @@ -252,8 +254,8 @@ pub struct ReadMEM { impl ReadMEM { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, - mem_addr: MemoryExpr, - mem_read: [Expression; UINT_LIMBS], + mem_addr: AddressExpr, + mem_read: Expression, cur_ts: WitIn, ) -> Result { let prev_ts = circuit_builder.create_witin(|| "prev_ts")?; @@ -261,7 +263,7 @@ impl ReadMEM { || "read_memory", &mem_addr, prev_ts.expr(), - cur_ts.expr() + (Tracer::SUBCYCLE_MEM as usize).into(), + cur_ts.expr() + Tracer::SUBCYCLE_MEM, mem_read, )?; @@ -298,39 +300,34 @@ impl ReadMEM { } 
#[derive(Debug)] -pub struct WriteMEM { +pub struct WriteMEM { pub prev_ts: WitIn, - pub prev_value: UInt, pub lt_cfg: AssertLTConfig, } -impl WriteMEM { - pub fn construct_circuit( +impl WriteMEM { + pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, - mem_addr: MemoryExpr, - mem_written: [Expression; UINT_LIMBS], + mem_addr: AddressExpr, + prev_value: MemoryExpr, + new_value: MemoryExpr, cur_ts: WitIn, ) -> Result { let prev_ts = circuit_builder.create_witin(|| "prev_ts")?; - let prev_value = UInt::new_unchecked(|| "prev_memory_value", circuit_builder)?; let (_, lt_cfg) = circuit_builder.memory_write( || "write_memory", &mem_addr, prev_ts.expr(), - cur_ts.expr() + (Tracer::SUBCYCLE_RD as usize).into(), - prev_value.memory_expr(), - mem_written, + cur_ts.expr() + Tracer::SUBCYCLE_MEM, + prev_value, + new_value, )?; - Ok(WriteMEM { - prev_ts, - prev_value, - lt_cfg, - }) + Ok(WriteMEM { prev_ts, lt_cfg }) } - pub fn assign_instance( + pub fn assign_instance( &self, instance: &mut [MaybeUninit<::BaseField>], lk_multiplicity: &mut LkMultiplicity, @@ -342,13 +339,6 @@ impl WriteMEM { step.memory_op().unwrap().previous_cycle ); - // Memory State - self.prev_value.assign_value( - instance, - Value::new_unchecked(step.memory_op().unwrap().value.before), - ); - - // Memory Write self.lt_cfg.assign_instance( instance, lk_multiplicity, @@ -359,3 +349,210 @@ impl WriteMEM { Ok(()) } } + +#[derive(Debug)] +pub struct MemAddr { + addr: UInt, + low_bits: Vec, +} + +impl MemAddr { + const N_LOW_BITS: usize = 2; + + /// An address which is range-checked, and not aligned. Bits 0 and 1 are variables. + pub fn construct_unaligned(cb: &mut CircuitBuilder) -> Result { + Self::construct(cb, 0) + } + + /// An address which is range-checked, and aligned to 2 bytes. Bit 0 is constant 0. Bit 1 is variable. + pub fn construct_align2(cb: &mut CircuitBuilder) -> Result { + Self::construct(cb, 1) + } + + /// An address which is range-checked, and aligned to 4 bytes. 
Bits 0 and 1 are constant 0. + pub fn construct_align4(cb: &mut CircuitBuilder) -> Result { + Self::construct(cb, 2) + } + + /// Represent the address as an expression. + pub fn expr_unaligned(&self) -> AddressExpr { + self.addr.address_expr() + } + + /// Represent the address aligned to 2 bytes. + pub fn expr_align2(&self) -> AddressExpr { + self.addr.address_expr() - self.low_bit_exprs()[0].clone() + } + + /// Represent the address aligned to 4 bytes. + pub fn expr_align4(&self) -> AddressExpr { + let low_bits = self.low_bit_exprs(); + self.addr.address_expr() - low_bits[1].clone() * 2 - low_bits[0].clone() + } + + /// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1]. + pub fn low_bit_exprs(&self) -> Vec> { + iter::repeat_n(Expression::ZERO, self.n_zeros()) + .chain(self.low_bits.iter().map(ToExpr::expr)) + .collect() + } + + fn construct(cb: &mut CircuitBuilder, n_zeros: usize) -> Result { + assert!(n_zeros <= Self::N_LOW_BITS); + + // The address as two u16 limbs. + // Soundness: This does not use the UInt range-check but specialized checks instead. + let addr = UInt::new_unchecked(|| "memory_addr", cb)?; + let limbs = addr.expr(); + + // Witness and constrain the non-zero low bits. + let low_bits = (n_zeros..Self::N_LOW_BITS) + .map(|i| { + let bit = cb.create_witin(|| format!("addr_bit_{}", i))?; + cb.assert_bit(|| format!("addr_bit_{}", i), bit.expr())?; + Ok(bit) + }) + .collect::, ZKVMError>>()?; + + // Express the value of the low bits. + let low_sum: Expression = (n_zeros..Self::N_LOW_BITS) + .zip_eq(low_bits.iter()) + .map(|(pos, bit)| bit.expr() * (1 << pos)) + .sum(); + + // Range check the middle bits, that is the low limb excluding the low bits. + let shift_right = E::BaseField::from(1 << Self::N_LOW_BITS) + .invert() + .unwrap() + .expr(); + let mid_u14 = (limbs[0].clone() - low_sum) * shift_right; + cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; + + // Range check the high limb. 
+ for high_u16 in limbs.iter().skip(1) { + cb.assert_ux::<_, _, 16>(|| "high_u16", high_u16.clone())?; + } + + Ok(MemAddr { addr, low_bits }) + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lkm: &mut LkMultiplicity, + addr: Word, + ) -> Result<(), ZKVMError> { + self.addr.assign_value(instance, Value::new_unchecked(addr)); + + // Witness the non-zero low bits. + for (pos, bit) in (self.n_zeros()..Self::N_LOW_BITS).zip_eq(&self.low_bits) { + let b = (addr >> pos) & 1; + set_val!(instance, bit, b as u64); + } + + // Range check the low limb besides the low bits. + let mid_u14 = (addr & 0xffff) >> Self::N_LOW_BITS; + lkm.assert_ux::<14>(mid_u14 as u64); + + // Range check the high limb. + for i in 1..UINT_LIMBS { + let high_u16 = (addr >> (i * 16)) & 0xffff; + lkm.assert_ux::<16>(high_u16 as u64); + } + + Ok(()) + } + + fn n_zeros(&self) -> usize { + Self::N_LOW_BITS - self.low_bits.len() + } +} + +#[cfg(test)] +mod test { + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + + use crate::{ + ROMType, + circuit_builder::{CircuitBuilder, ConstraintSystem}, + error::ZKVMError, + scheme::mock_prover::MockProver, + witness::{LkMultiplicity, RowMajorMatrix}, + }; + + use super::MemAddr; + + #[test] + fn test_mem_addr() -> Result<(), ZKVMError> { + let aligned_1 = 0xbeadbeef; + let aligned_2 = 0xbeadbeee; + let aligned_4 = 0xbeadbeec; + + impl_test_mem_addr(1, aligned_1, true)?; + impl_test_mem_addr(1, aligned_2, true)?; + impl_test_mem_addr(1, aligned_4, true)?; + + impl_test_mem_addr(2, aligned_1, false)?; + impl_test_mem_addr(2, aligned_2, true)?; + impl_test_mem_addr(2, aligned_4, true)?; + + impl_test_mem_addr(4, aligned_1, false)?; + impl_test_mem_addr(4, aligned_2, false)?; + impl_test_mem_addr(4, aligned_4, true)?; + Ok(()) + } + + fn impl_test_mem_addr(align: u32, addr: u32, is_ok: bool) -> Result<(), ZKVMError> { + let mut cs = 
ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let mem_addr = match align { + 1 => MemAddr::construct_unaligned(&mut cb)?, + 2 => MemAddr::construct_align2(&mut cb)?, + 4 => MemAddr::construct_align4(&mut cb)?, + _ => unreachable!(), + }; + + let mut lkm = LkMultiplicity::default(); + let num_rows = 2; + let mut raw_witin = RowMajorMatrix::::new(num_rows, cb.cs.num_witin as usize); + for instance in raw_witin.iter_mut() { + mem_addr.assign_instance(instance, &mut lkm, addr)?; + } + + // Check the range lookups. + let lkm = lkm.into_finalize_result(); + lkm[ROMType::U14 as usize].iter().for_each(|(k, v)| { + assert_eq!(*k, 0xbeef >> 2); + assert_eq!(*v, num_rows); + }); + assert_eq!(lkm[ROMType::U14 as usize].len(), 1); + lkm[ROMType::U16 as usize].iter().for_each(|(k, v)| { + assert_eq!(*k, 0xbead); + assert_eq!(*v, num_rows); + }); + assert_eq!(lkm[ROMType::U16 as usize].len(), 1); + + if is_ok { + cb.require_equal(|| "", mem_addr.expr_unaligned(), addr.into())?; + cb.require_equal(|| "", mem_addr.expr_align2(), (addr >> 1 << 1).into())?; + cb.require_equal(|| "", mem_addr.expr_align4(), (addr >> 2 << 2).into())?; + } + MockProver::assert_with_expected_errors( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[], + if is_ok { &[] } else { &["mid_u14"] }, + None, + None, + ); + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index c87326b55..50708d07a 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,9 +1,11 @@ mod auipc; mod jal; +mod jalr; mod lui; pub use auipc::AuipcInstruction; pub use jal::JalInstruction; +pub use jalr::JalrInstruction; pub use lui::LuiInstruction; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs index e60c320bd..6c979ec25 100644 --- 
a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs @@ -56,7 +56,7 @@ impl Instruction for AuipcInstruction { circuit_builder.require_equal( || "imm+pc = rd_written+2^32*overflow", imm.expr() + u_insn.vm_state.pc.expr(), - rd_written.value() + overflow_bit.expr() * (1u64 << 32).into(), + rd_written.value() + overflow_bit.expr() * (1u64 << 32), )?; Ok(AuipcConfig { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs index be7b4731d..44facf944 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -54,7 +54,7 @@ impl Instruction for JalInstruction { circuit_builder.require_equal( || "jal rd_written", rd_written.value(), - j_insn.vm_state.pc.expr() + PC_STEP_SIZE.into(), + j_insn.vm_state.pc.expr() + PC_STEP_SIZE, )?; Ok(JalConfig { j_insn, rd_written }) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs new file mode 100644 index 000000000..889f3eca8 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -0,0 +1,142 @@ +use std::{marker::PhantomData, mem::MaybeUninit}; + +use ff::Field; +use ff_ext::ExtensionField; + +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{ToExpr, WitIn}, + instructions::{ + Instruction, + riscv::{constants::UInt, i_insn::IInstructionConfig, insn_base::MemAddr}, + }, + set_val, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; + +pub struct JalrConfig { + pub i_insn: IInstructionConfig, + pub rs1_read: UInt, + pub imm: WitIn, + pub next_pc_addr: MemAddr, + pub overflow: WitIn, + pub rd_written: UInt, +} + +pub struct JalrInstruction(PhantomData); + +/// JALR instruction circuit +/// NOTE: does not validate that next_pc is aligned by 4-byte increments, which +/// should be verified by lookup argument of the next execution step 
against +/// the program table +impl Instruction for JalrInstruction { + type InstructionConfig = JalrConfig; + + fn name() -> String { + format!("{:?}", InsnKind::JALR) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result, ZKVMError> { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let i_insn = IInstructionConfig::construct_circuit( + circuit_builder, + InsnKind::JALR, + &imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + true, + )?; + + // Next pc is obtained by rounding rs1+imm down to an even value. + // To implement this, check three conditions: + // 1. rs1 + imm = next_pc_addr + overflow*2^32 + // 2. overflow in {-1, 0, 1} + // 3. next_pc = next_pc_addr aligned to even value (round down) + + let next_pc_addr = MemAddr::::construct_unaligned(circuit_builder)?; + let overflow = circuit_builder.create_witin(|| "overflow")?; + + circuit_builder.require_equal( + || "rs1+imm = next_pc_unrounded + overflow*2^32", + rs1_read.value() + imm.expr(), + next_pc_addr.expr_unaligned() + overflow.expr() * (1u64 << 32), + )?; + + circuit_builder.require_zero( + || "overflow_0_or_pm1", + overflow.expr() * (overflow.expr() - 1) * (overflow.expr() + 1), + )?; + + circuit_builder.require_equal( + || "next_pc_addr = next_pc", + next_pc_addr.expr_align2(), + i_insn.vm_state.next_pc.unwrap().expr(), + )?; + + // write pc+4 to rd + circuit_builder.require_equal( + || "rd_written = pc+4", + rd_written.value(), + i_insn.vm_state.pc.expr() + PC_STEP_SIZE, + )?; + + Ok(JalrConfig { + i_insn, + rs1_read, + imm, + next_pc_addr, + overflow, + rd_written, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), 
ZKVMError> { + let insn = step.insn(); + + let rs1 = step.rs1().unwrap().value; + let imm: i32 = insn.imm_or_funct7() as i32; + let rd = step.rd().unwrap().value.after; + + let (sum, overflowing) = rs1.overflowing_add_signed(imm); + + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + config + .rd_written + .assign_value(instance, Value::new(rd, lk_multiplicity)); + + let imm_field = InsnRecord::imm_or_funct7_field::(&insn); + set_val!(instance, config.imm, imm_field); + + config + .next_pc_addr + .assign_instance(instance, lk_multiplicity, sum)?; + let overflow: E::BaseField = match (overflowing, imm < 0) { + (false, _) => E::BaseField::ZERO, + (true, false) => E::BaseField::ONE, + (true, true) => -E::BaseField::ONE, + }; + set_val!(instance, config.overflow, overflow); + + config + .i_insn + .assign_instance(instance, lk_multiplicity, step)?; + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index b143dd0b6..a1b17e911 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -1,15 +1,18 @@ -use ceno_emul::{ByteAddr, Change, PC_STEP_SIZE, StepRecord}; +use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::Instruction, - scheme::mock_prover::{MOCK_PC_AUIPC, MOCK_PC_JAL, MOCK_PC_LUI, MOCK_PROGRAM, MockProver}, + instructions::{ + Instruction, + riscv::test_utils::{imm_j, imm_u}, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, }; -use super::{AuipcInstruction, JalInstruction, LuiInstruction}; +use super::{AuipcInstruction, JalInstruction, JalrInstruction, LuiInstruction}; #[test] fn test_opcode_jal() { @@ -26,16 +29,17 @@ fn test_opcode_jal() { .unwrap() .unwrap(); - let pc_offset: i32 = 
-4i32; - let new_pc: ByteAddr = ByteAddr(MOCK_PC_JAL.0.wrapping_add_signed(pc_offset)); - let (raw_witin, _lkm) = JalInstruction::::assign_instances( + let pc_offset: i32 = -8i32; + let new_pc: ByteAddr = ByteAddr(MOCK_PC_START.0.wrapping_add_signed(pc_offset)); + let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, imm_j(pc_offset)); + let (raw_witin, lkm) = JalInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_j_instruction( 4, - Change::new(MOCK_PC_JAL, new_pc), - MOCK_PROGRAM[21], - Change::new(0, (MOCK_PC_JAL + PC_STEP_SIZE).into()), + Change::new(MOCK_PC_START, new_pc), + insn_code, + Change::new(0, (MOCK_PC_START + PC_STEP_SIZE).into()), 0, )], ) @@ -49,7 +53,57 @@ fn test_opcode_jal() { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), + ); +} + +#[test] +fn test_opcode_jalr() { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "jalr", + |cb| { + let config = JalrInstruction::::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let imm = -15i32; + let rs1_read: Word = 10u32; + let new_pc: ByteAddr = ByteAddr(rs1_read.wrapping_add_signed(imm) & (!1)); + let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm as u32); + + let (raw_witin, lkm) = JalrInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_i_instruction( + 4, + Change::new(MOCK_PC_START, new_pc), + insn_code, + rs1_read, + Change::new(0, (MOCK_PC_START + PC_STEP_SIZE).into()), + 0, + )], + ) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), ); } @@ -68,16 +122,16 @@ fn test_opcode_lui() { .unwrap() .unwrap(); - let lui_insn = MOCK_PROGRAM[22]; - let imm = lui_insn & 0xfffff000; - let (raw_witin, _lkm) = LuiInstruction::::assign_instances( + 
let imm_value = imm_u(0x90005); + let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm_value); + let (raw_witin, lkm) = LuiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_u_instruction( 4, - MOCK_PC_LUI, - lui_insn, - Change::new(0, imm), + MOCK_PC_START, + insn_code, + Change::new(0, imm_value), 0, )], ) @@ -91,7 +145,9 @@ fn test_opcode_lui() { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } @@ -110,16 +166,16 @@ fn test_opcode_auipc() { .unwrap() .unwrap(); - let auipc_insn = MOCK_PROGRAM[23]; - let imm = auipc_insn & 0xfffff000; - let (raw_witin, _lkm) = AuipcInstruction::::assign_instances( + let imm_value = imm_u(0x90005); + let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm_value); + let (raw_witin, lkm) = AuipcInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_u_instruction( 4, - MOCK_PC_AUIPC, - auipc_insn, - Change::new(0, MOCK_PC_AUIPC.0.wrapping_add(imm)), + MOCK_PC_START, + insn_code, + Change::new(0, MOCK_PC_START.0.wrapping_add(imm_value)), 0, )], ) @@ -133,6 +189,8 @@ fn test_opcode_auipc() { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index ae3367e8f..50d73751e 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -1,13 +1,12 @@ -use ceno_emul::{Change, StepRecord, Word}; +use ceno_emul::{Change, StepRecord, Word, encode_rv32}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; use crate::{ - ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, instructions::{Instruction, riscv::constants::UInt8}, - scheme::mock_prover::{MOCK_PC_AND, MOCK_PC_OR, MOCK_PC_XOR, MOCK_PROGRAM, MockProver}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, 
utils::split_to_u8, }; @@ -15,8 +14,6 @@ use super::*; const A: Word = 0xbead1010; const B: Word = 0xef552020; -// The pair of bytes from A and B. -const LOOKUPS: &[(u64, usize)] = &[(0x2010, 2), (0x55ad, 1), (0xefbe, 1)]; #[test] fn test_opcode_and() { @@ -33,12 +30,13 @@ fn test_opcode_and() { .unwrap() .unwrap(); + let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); let (raw_witin, lkm) = AndInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ StepRecord::new_r_instruction( 3, - MOCK_PC_AND, - MOCK_PROGRAM[3], + MOCK_PC_START, + insn_code, A, B, Change::new(0, A & B), @@ -47,9 +45,6 @@ fn test_opcode_and() { ]) .unwrap(); - let lkm = lkm.into_finalize_result()[ROMType::And as usize].clone(); - assert_eq!(&lkm.into_iter().sorted().collect_vec(), LOOKUPS); - let expected_rd_written = UInt8::from_const_unchecked(split_to_u8::(A & B)); config @@ -65,7 +60,9 @@ fn test_opcode_and() { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } @@ -84,12 +81,13 @@ fn test_opcode_or() { .unwrap() .unwrap(); + let insn_code = encode_rv32(InsnKind::OR, 2, 3, 4, 0); let (raw_witin, lkm) = OrInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ StepRecord::new_r_instruction( 3, - MOCK_PC_OR, - MOCK_PROGRAM[4], + MOCK_PC_START, + insn_code, A, B, Change::new(0, A | B), @@ -98,9 +96,6 @@ fn test_opcode_or() { ]) .unwrap(); - let lkm = lkm.into_finalize_result()[ROMType::Or as usize].clone(); - assert_eq!(&lkm.into_iter().sorted().collect_vec(), LOOKUPS); - let expected_rd_written = UInt8::from_const_unchecked(split_to_u8::(A | B)); config @@ -116,7 +111,9 @@ fn test_opcode_or() { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } @@ -135,12 +132,13 @@ fn test_opcode_xor() { .unwrap() .unwrap(); + let insn_code = encode_rv32(InsnKind::XOR, 2, 3, 4, 0); let (raw_witin, lkm) = XorInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ 
StepRecord::new_r_instruction( 3, - MOCK_PC_XOR, - MOCK_PROGRAM[5], + MOCK_PC_START, + insn_code, A, B, Change::new(0, A ^ B), @@ -149,9 +147,6 @@ fn test_opcode_xor() { ]) .unwrap(); - let lkm = lkm.into_finalize_result()[ROMType::Xor as usize].clone(); - assert_eq!(&lkm.into_iter().sorted().collect_vec(), LOOKUPS); - let expected_rd_written = UInt8::from_const_unchecked(split_to_u8::(A ^ B)); config @@ -167,6 +162,8 @@ fn test_opcode_xor() { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs new file mode 100644 index 000000000..43b7d509f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -0,0 +1,26 @@ +mod logic_imm_circuit; +use logic_imm_circuit::{LogicInstruction, LogicOp}; + +use crate::tables::{AndTable, OrTable, XorTable}; +use ceno_emul::InsnKind; + +pub struct AndiOp; +impl LogicOp for AndiOp { + const INST_KIND: InsnKind = InsnKind::ANDI; + type OpsTable = AndTable; +} +pub type AndiInstruction = LogicInstruction; + +pub struct OriOp; +impl LogicOp for OriOp { + const INST_KIND: InsnKind = InsnKind::ORI; + type OpsTable = OrTable; +} +pub type OriInstruction = LogicInstruction; + +pub struct XoriOp; +impl LogicOp for XoriOp { + const INST_KIND: InsnKind = InsnKind::XORI; + type OpsTable = XorTable; +} +pub type XoriInstruction = LogicInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs new file mode 100644 index 000000000..8972766f5 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -0,0 +1,221 @@ +//! The circuit implementation of logic instructions. 
+ +use core::mem::MaybeUninit; +use ff_ext::ExtensionField; +use std::marker::PhantomData; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{constants::UInt8, i_insn::IInstructionConfig}, + }, + tables::OpsTable, + utils::split_to_u8, + witness::LkMultiplicity, +}; +use ceno_emul::{InsnKind, StepRecord}; + +/// This trait defines a logic instruction, connecting an instruction type to a lookup table. +pub trait LogicOp { + const INST_KIND: InsnKind; + type OpsTable: OpsTable; +} + +/// The Instruction circuit for a given LogicOp. +pub struct LogicInstruction(PhantomData<(E, I)>); + +impl Instruction for LogicInstruction { + type InstructionConfig = LogicConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let config = LogicConfig::construct_circuit(cb, I::INST_KIND)?; + + // Constrain the registers based on the given lookup table. + UInt8::logic( + cb, + I::OpsTable::ROM_TYPE, + &config.rs1_read, + &config.imm, + &config.rd_written, + )?; + + Ok(config) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit<::BaseField>], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + UInt8::::logic_assign::( + lkm, + step.rs1().unwrap().value.into(), + step.insn().imm_or_funct7().into(), + ); + + config.assign_instance(instance, lkm, step) + } +} + +/// This config implements I-Instructions that represent registers values as 4 * u8. +/// Non-generic code shared by several circuits. 
+#[derive(Debug)] +pub struct LogicConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt8, + pub(crate) rd_written: UInt8, + imm: UInt8, +} + +impl LogicConfig { + fn construct_circuit( + cb: &mut CircuitBuilder, + insn_kind: InsnKind, + ) -> Result { + let rs1_read = UInt8::new_unchecked(|| "rs1_read", cb)?; + let rd_written = UInt8::new_unchecked(|| "rd_written", cb)?; + let imm = UInt8::new_unchecked(|| "imm", cb)?; + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + insn_kind, + &imm.value(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(Self { + i_insn, + rs1_read, + imm, + rd_written, + }) + } + + fn assign_instance( + &self, + instance: &mut [MaybeUninit<::BaseField>], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + self.i_insn.assign_instance(instance, lkm, step)?; + + let rs1_read = split_to_u8(step.rs1().unwrap().value); + self.rs1_read.assign_limbs(instance, &rs1_read); + + let imm = split_to_u8::(step.insn().imm_or_funct7()); + self.imm.assign_limbs(instance, &imm); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + self.rd_written.assign_limbs(instance, &rd_written); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use goldilocks::GoldilocksExt2; + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + + use crate::{ + chip_handler::test::DebugIndex, + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::{ + constants::UInt8, + logic_imm::{AndiOp, OriOp, XoriOp, logic_imm_circuit::LogicInstruction}, + }, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + utils::split_to_u8, + }; + + use super::LogicOp; + + #[test] + fn test_opcode_andi() { + verify::("basic", 0x0000_0011, 3, 0x0000_0011 & 3); + verify::("zero result", 0x0000_0100, 3, 0x0000_0100 & 3); + } + + #[test] + fn test_opcode_ori() { + verify::("basic", 
0x0000_0011, 3, 0x0000_0011 | 3); + verify::("basic2", 0x0000_0100, 3, 0x0000_0100 | 3); + } + + #[test] + fn test_opcode_xori() { + verify::("basic", 0x0000_0011, 3, 0x0000_0011 ^ 3); + verify::("non-overlap", 0x0000_0100, 3, 0x0000_0100 ^ 3); + } + + fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + + let (prefix, rd_written) = match I::INST_KIND { + InsnKind::ANDI => ("ANDI", rs1_read & imm), + InsnKind::ORI => ("ORI", rs1_read | imm), + InsnKind::XORI => ("XORI", rs1_read ^ imm), + _ => unreachable!(), + }; + + let config = cb + .namespace( + || format!("{prefix}_({name})"), + |cb| { + let config = LogicInstruction::::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm); + let (raw_witin, lkm) = LogicInstruction::::assign_instances( + &config, + cb.cs.num_witin as usize, + vec![StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + rs1_read, + Change::new(0, rd_written), + 0, + )], + ) + .unwrap(); + + let expected = UInt8::from_const_unchecked(split_to_u8::(expected_rd_written)); + let rd_written_expr = cb.get_debug_expr(DebugIndex::RdWrite as usize)[0].clone(); + cb.require_equal(|| "assert_rd_written", rd_written_expr, expected.value()) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/mem.rs b/ceno_zkvm/src/instructions/riscv/mem.rs index c6ffd9edf..e69de29bb 100644 --- a/ceno_zkvm/src/instructions/riscv/mem.rs +++ b/ceno_zkvm/src/instructions/riscv/mem.rs @@ -1,172 +0,0 @@ -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - instructions::{ - Instruction, - riscv::{ - 
RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, - s_insn::SInstructionConfig, - }, - }, - witness::LkMultiplicity, -}; -use ceno_emul::{InsnKind, StepRecord}; -use ff_ext::ExtensionField; -use std::{marker::PhantomData, mem::MaybeUninit}; - -pub struct StoreConfig { - s_insn: SInstructionConfig, - - rs1_read: UInt, - rs2_read: UInt, - imm: UInt, -} - -pub struct StoreInstruction(PhantomData<(E, I)>); - -pub struct SWOp; - -impl RIVInstruction for SWOp { - const INST_KIND: InsnKind = InsnKind::SW; -} - -#[allow(dead_code)] -pub type StoreWord = StoreInstruction; - -impl Instruction for StoreInstruction { - type InstructionConfig = StoreConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - ) -> Result { - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; - let imm = UInt::new(|| "imm", circuit_builder)?; - - let memory_addr = rs1_read.add(|| "memory_addr", circuit_builder, &imm, true)?; - - let memory_value = match I::INST_KIND { - InsnKind::SW => rs2_read.memory_expr(), - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - let s_insn = SInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - &imm.value(), - rs1_read.register_expr(), - rs2_read.register_expr(), - memory_addr.memory_expr(), - memory_value, - )?; - - Ok(StoreConfig { - s_insn, - rs1_read, - rs2_read, - imm, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [MaybeUninit], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let rs1 = Value::new_unchecked(step.rs1().unwrap().value); - let rs2 = Value::new_unchecked(step.rs2().unwrap().value); - let imm = Value::new(step.insn().imm_or_funct7(), lk_multiplicity); - - config - .s_insn - .assign_instance(instance, lk_multiplicity, step)?; - 
config.rs1_read.assign_value(instance, rs1); - config.rs2_read.assign_value(instance, rs2); - config.imm.assign_value(instance, imm); - - Ok(()) - } -} - -pub struct LoadConfig { - im_insn: IMInstructionConfig, - - rs1_read: UInt, - imm: UInt, - memory_read: UInt, -} - -pub struct LoadInstruction(PhantomData<(E, I)>); - -pub struct LWOp; - -impl RIVInstruction for LWOp { - const INST_KIND: InsnKind = InsnKind::LW; -} - -impl Instruction for LoadInstruction { - type InstructionConfig = LoadConfig; - - fn name() -> String { - format!("{:?}", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - ) -> Result { - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let imm = UInt::new(|| "imm", circuit_builder)?; - let memory_read = UInt::new_unchecked(|| "memory_read", circuit_builder)?; - - let (memory_addr, memory_value) = match I::INST_KIND { - InsnKind::LW => ( - rs1_read.add(|| "memory_addr", circuit_builder, &imm, true)?, - memory_read.register_expr(), - ), - _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), - }; - - let im_insn = IMInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - &imm.value(), - rs1_read.register_expr(), - memory_read.memory_expr(), - memory_addr.memory_expr(), - memory_value, - )?; - - Ok(LoadConfig { - im_insn, - rs1_read, - memory_read, - imm, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [MaybeUninit], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let rs1 = Value::new_unchecked(step.rs1().unwrap().value); - let memory_read = Value::new_unchecked(step.memory_op().unwrap().value.before); - let imm = Value::new(step.insn().imm_or_funct7(), lk_multiplicity); - - config - .im_insn - .assign_instance(instance, lk_multiplicity, step)?; - config.rs1_read.assign_value(instance, rs1); - config.memory_read.assign_value(instance, memory_read); - 
config.imm.assign_value(instance, imm); - - Ok(()) - } -} diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs new file mode 100644 index 000000000..71c8cba72 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -0,0 +1,10 @@ +mod gadget; +pub mod load; +pub mod store; + +#[cfg(test)] +mod test; + +pub use load::LwInstruction; +#[cfg(test)] +pub use store::{SbInstruction, ShInstruction, SwInstruction}; diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs new file mode 100644 index 000000000..0980253db --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -0,0 +1,208 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + instructions::riscv::{constants::UInt, insn_base::MemAddr}, + set_val, + witness::LkMultiplicity, +}; +use ceno_emul::StepRecord; +use ff::Field; +use ff_ext::ExtensionField; +use itertools::izip; +use std::mem::MaybeUninit; + +pub struct MemWordChange { + prev_limb_bytes: Vec, + rs2_limb_bytes: Vec, + + expected_changes: Vec, +} + +impl MemWordChange { + pub(crate) fn construct_circuit( + cb: &mut CircuitBuilder, + addr: &MemAddr, + prev_word: &UInt, + rs2_word: &UInt, + ) -> Result { + let alloc_bytes = |cb: &mut CircuitBuilder, + anno: &str, + num_bytes: usize| + -> Result, ZKVMError> { + (0..num_bytes) + .map(|i| { + let byte = cb.create_witin(|| format!("{}.le_bytes[{}]", anno, i))?; + cb.assert_ux::<_, _, 8>(|| "byte range check", byte.expr())?; + + Ok(byte) + }) + .collect() + }; + + let decompose_limb = |cb: &mut CircuitBuilder, + limb_anno: &str, + limb: &Expression, + num_bytes: usize| + -> Result, ZKVMError> { + let bytes = alloc_bytes(cb, limb_anno, num_bytes)?; + + cb.require_equal( + || format!("decompose {} into {} bytes", limb_anno, num_bytes), + limb.clone(), + bytes + .iter() + .enumerate() + .map(|(idx, byte)| (1 << (idx 
* 8)) * byte.expr()) + .sum(), + )?; + + Ok(bytes) + }; + + assert_eq!(UInt::::NUM_LIMBS, 2); + // for sb (n_zeros = 0) + match N_ZEROS { + 0 => { + assert!(prev_word.wits_in().is_some() && rs2_word.wits_in().is_some()); + + let low_bits = addr.low_bit_exprs(); + let prev_limbs = prev_word.expr(); + let rs2_limbs = rs2_word.expr(); + + // degree 2 expression + let prev_target_limb = cb.select(&low_bits[1], &prev_limbs[1], &prev_limbs[0]); + let prev_limb_bytes = decompose_limb(cb, "prev_limb", &prev_target_limb, 2)?; + + // extract the least significant byte from u16 limb + let rs2_limb_bytes = alloc_bytes(cb, "rs2_limb[0]", 1)?; + let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); + cb.assert_ux::<_, _, 8>( + || "rs2_limb[0].le_bytes[1]", + u8_base_inv.expr() * (rs2_limbs[0].clone() - rs2_limb_bytes[0].expr()), + )?; + + // alloc a new witIn to cache degree 2 expression + let expected_limb_change = cb.create_witin(|| "expected_limb_change")?; + cb.condition_require_equal( + || "expected_limb_change = select(low_bits[0], rs2 - prev)", + low_bits[0].clone(), + expected_limb_change.expr(), + (1 << 8) * (rs2_limb_bytes[0].expr() - prev_limb_bytes[1].expr()), + rs2_limb_bytes[0].expr() - prev_limb_bytes[0].expr(), + )?; + + // alloc a new witIn to cache degree 2 expression + let expected_change = cb.create_witin(|| "expected_change")?; + cb.condition_require_equal( + || "expected_change = select(low_bits[1], limb_change*2^16, limb_change)", + low_bits[1].clone(), + expected_change.expr(), + (1 << 16) * expected_limb_change.expr(), + expected_limb_change.expr(), + )?; + + Ok(MemWordChange { + prev_limb_bytes, + rs2_limb_bytes, + expected_changes: vec![expected_limb_change, expected_change], + }) + } + // for sh (n_zeros = 1) + 1 => { + assert!(prev_word.wits_in().is_some() && rs2_word.wits_in().is_some()); + + let low_bits = addr.low_bit_exprs(); + let prev_limbs = prev_word.expr(); + let rs2_limbs = rs2_word.expr(); + + let expected_change = 
cb.create_witin(|| "expected_change")?; + + // alloc a new witIn to cache degree 2 expression + cb.condition_require_equal( + || "expected_change = select(low_bits[1], 2^16*(limb_change))", + // degree 2 expression + low_bits[1].clone(), + expected_change.expr(), + (1 << 16) * (rs2_limbs[0].clone() - prev_limbs[1].clone()), + rs2_limbs[0].clone() - prev_limbs[0].clone(), + )?; + + Ok(MemWordChange { + prev_limb_bytes: vec![], + rs2_limb_bytes: vec![], + expected_changes: vec![expected_change], + }) + } + _ => unreachable!("N_ZEROS cannot be larger than 1"), + } + } + + pub(crate) fn value(&self) -> Expression { + assert!(N_ZEROS <= 1); + + self.expected_changes[1 - N_ZEROS].expr() + } + + pub fn assign_instance( + &self, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + shift: u32, + ) -> Result<(), ZKVMError> { + let memory_op = step.memory_op().clone().unwrap(); + let prev_value = Value::new_unchecked(memory_op.value.before); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + + let low_bits = [shift & 1, (shift >> 1) & 1]; + let prev_limb = prev_value.as_u16_limbs()[low_bits[1] as usize]; + let rs2_limb = rs2_value.as_u16_limbs()[0]; + + match N_ZEROS { + 0 => { + for (&col, byte) in izip!(&self.prev_limb_bytes, prev_limb.to_le_bytes()) { + set_val!(instance, col, E::BaseField::from(byte as u64)); + lk_multiplicity.assert_ux::<8>(byte as u64); + } + + set_val!( + instance, + self.rs2_limb_bytes[0], + E::BaseField::from(rs2_limb.to_le_bytes()[0] as u64) + ); + + rs2_limb.to_le_bytes().into_iter().for_each(|byte| { + lk_multiplicity.assert_ux::<8>(byte as u64); + }); + let change = if low_bits[0] == 0 { + E::BaseField::from(rs2_limb.to_le_bytes()[0] as u64) + - E::BaseField::from(prev_limb.to_le_bytes()[0] as u64) + } else { + E::BaseField::from((rs2_limb.to_le_bytes()[0] as u64) << 8) + - E::BaseField::from((prev_limb.to_le_bytes()[1] as u64) << 8) + }; + let final_change = if low_bits[1] == 0 { + 
change + } else { + E::BaseField::from(1u64 << 16) * change + }; + set_val!(instance, self.expected_changes[0], change); + set_val!(instance, self.expected_changes[1], final_change); + } + 1 => { + let final_change = if low_bits[1] == 0 { + E::BaseField::from(rs2_limb as u64) - E::BaseField::from(prev_limb as u64) + } else { + E::BaseField::from((rs2_limb as u64) << 16) + - E::BaseField::from((prev_limb as u64) << 16) + }; + set_val!(instance, self.expected_changes[0], final_change); + } + _ => unreachable!("N_ZEROS cannot be larger than 1"), + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs new file mode 100644 index 000000000..e7d7e20b5 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -0,0 +1,257 @@ +// will remove #[allow(dead_code)] when we finished fibonacci integration test +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + gadgets::SignedExtendConfig, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, + }, + }, + set_val, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{ByteAddr, InsnKind, StepRecord}; +use ff_ext::ExtensionField; +use itertools::izip; +use std::{marker::PhantomData, mem::MaybeUninit}; + +pub struct LoadConfig { + im_insn: IMInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + memory_addr: MemAddr, + + memory_read: UInt, + target_limb: Option, + target_limb_bytes: Option>, + signed_extend_config: Option, +} + +pub struct LoadInstruction(PhantomData<(E, I)>); + +pub struct LwOp; + +impl RIVInstruction for LwOp { + const INST_KIND: InsnKind = InsnKind::LW; +} + +pub type LwInstruction = LoadInstruction; + +pub struct LhOp; +impl RIVInstruction for LhOp { + const INST_KIND: InsnKind = InsnKind::LH; +} +#[allow(dead_code)] +pub type LhInstruction = LoadInstruction; + +pub 
struct LhuOp; +impl RIVInstruction for LhuOp { + const INST_KIND: InsnKind = InsnKind::LHU; +} +#[allow(dead_code)] +pub type LhuInstruction = LoadInstruction; + +pub struct LbOp; +impl RIVInstruction for LbOp { + const INST_KIND: InsnKind = InsnKind::LB; +} +#[allow(dead_code)] +pub type LbInstruction = LoadInstruction; + +pub struct LbuOp; +impl RIVInstruction for LbuOp { + const INST_KIND: InsnKind = InsnKind::LBU; +} +#[allow(dead_code)] +pub type LbuInstruction = LoadInstruction; + +impl Instruction for LoadInstruction { + type InstructionConfig = LoadConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + let memory_read = UInt::new(|| "memory_read", circuit_builder)?; + + let memory_addr = match I::INST_KIND { + InsnKind::LW => MemAddr::construct_align4(circuit_builder), + InsnKind::LH | InsnKind::LHU => MemAddr::construct_align2(circuit_builder), + InsnKind::LB | InsnKind::LBU => MemAddr::construct_unaligned(circuit_builder), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }?; + + circuit_builder.require_equal( + || "memory_addr = rs1_read + imm", + memory_addr.expr_unaligned(), + rs1_read.value() + imm.expr(), + )?; + + let addr_low_bits = memory_addr.low_bit_exprs(); + let memory_value = memory_read.expr(); + + // get target limb from memory word for load instructions except LW + let target_limb = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU => { + let target_limb = circuit_builder.create_witin(|| "target_limb")?; + circuit_builder.condition_require_equal( + || "target_limb = memory_value[low_bits[1]]", + addr_low_bits[1].clone(), + target_limb.expr(), + memory_value[1].clone(), + memory_value[0].clone(), + )?; + 
Some(target_limb) + } + _ => None, + }; + + // get target byte from memory word for LB and LBU + let (target_byte_expr, target_limb_bytes) = match I::INST_KIND { + InsnKind::LB | InsnKind::LBU => { + let target_byte = circuit_builder.create_u8(|| "limb.le_bytes[low_bits[0]]")?; + let dummy_byte = circuit_builder.create_u8(|| "limb.le_bytes[1-low_bits[0]]")?; + + circuit_builder.condition_require_equal( + || "target_byte = target_limb[low_bits[0]]", + addr_low_bits[0].clone(), + target_limb.unwrap().expr(), + target_byte.expr() * (1<<8) + dummy_byte.expr(), // target_byte = limb.le_bytes[1] + dummy_byte.expr() * (1<<8) + target_byte.expr(), // target_byte = limb.le_bytes[0] + )?; + + ( + Some(target_byte.expr()), + Some(vec![target_byte, dummy_byte]), + ) + } + _ => (None, None), + }; + let (signed_extend_config, rd_written) = match I::INST_KIND { + InsnKind::LW => (None, memory_read.clone()), + InsnKind::LH => { + let val = target_limb.unwrap(); + let signed_extend_config = + SignedExtendConfig::construct_limb(circuit_builder, val.expr())?; + let rd_written = signed_extend_config.signed_extended_value(val.expr()); + + (Some(signed_extend_config), rd_written) + } + InsnKind::LHU => { + ( + None, + // it's safe to unwrap as `UInt::from_exprs_unchecked` never return error + UInt::from_exprs_unchecked(vec![ + target_limb.as_ref().map(|limb| limb.expr()).unwrap(), + Expression::ZERO, + ]), + ) + } + InsnKind::LB => { + let val = target_byte_expr.unwrap(); + let signed_extend_config = + SignedExtendConfig::construct_byte(circuit_builder, val.clone())?; + let rd_written = signed_extend_config.signed_extended_value(val); + + (Some(signed_extend_config), rd_written) + } + InsnKind::LBU => ( + None, + UInt::from_exprs_unchecked(vec![target_byte_expr.unwrap(), Expression::ZERO]), + ), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let im_insn = IMInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &imm.expr(), + 
rs1_read.register_expr(), + memory_read.memory_expr(), + memory_addr.expr_align4(), + rd_written.register_expr(), + )?; + + Ok(LoadConfig { + im_insn, + rs1_read, + imm, + memory_addr, + memory_read, + target_limb, + target_limb_bytes, + signed_extend_config, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let memory_value = step.memory_op().unwrap().value.before; + let memory_read = Value::new(memory_value, lk_multiplicity); + // imm is signed 12-bit value + let imm: E::BaseField = InsnRecord::imm_or_funct7_field(&step.insn()); + let unaligned_addr = ByteAddr::from( + step.rs1() + .unwrap() + .value + .wrapping_add(step.insn().imm_or_funct7()), + ); + let shift = unaligned_addr.shift(); + let addr_low_bits = [shift & 0x01, (shift >> 1) & 0x01]; + let target_limb = memory_read.as_u16_limbs()[addr_low_bits[1] as usize]; + let mut target_limb_bytes = target_limb.to_le_bytes(); + + set_val!(instance, config.imm, imm); + config + .im_insn + .assign_instance(instance, lk_multiplicity, step)?; + config.rs1_read.assign_value(instance, rs1); + config.memory_read.assign_value(instance, memory_read); + config + .memory_addr + .assign_instance(instance, lk_multiplicity, unaligned_addr.into())?; + if let Some(&limb) = config.target_limb.as_ref() { + set_val!(instance, limb, E::BaseField::from(target_limb as u64)); + } + if let Some(limb_bytes) = config.target_limb_bytes.as_ref() { + if addr_low_bits[0] == 1 { + // target_limb_bytes[0] = target_limb.to_le_bytes[1] + // target_limb_bytes[1] = target_limb.to_le_bytes[0] + target_limb_bytes.reverse(); + } + for (&col, byte) in izip!( + limb_bytes.iter(), + target_limb_bytes.into_iter().map(|byte| byte as u64) + ) { + lk_multiplicity.assert_ux::<8>(byte); + set_val!(instance, col, E::BaseField::from(byte)); + } + } + let val = 
match I::INST_KIND { + InsnKind::LB | InsnKind::LBU => target_limb_bytes[0] as u64, + InsnKind::LH | InsnKind::LHU => target_limb as u64, + _ => 0, + }; + if let Some(signed_ext_config) = config.signed_extend_config.as_ref() { + signed_ext_config.assign_instance::(instance, lk_multiplicity, val)?; + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs new file mode 100644 index 000000000..fc8f0455f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -0,0 +1,176 @@ +use crate::{ + Value, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{ToExpr, WitIn}, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, constants::UInt, insn_base::MemAddr, memory::gadget::MemWordChange, + s_insn::SInstructionConfig, + }, + }, + set_val, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{ByteAddr, CENO_PLATFORM, InsnKind, StepRecord}; +use ff_ext::ExtensionField; +use std::{marker::PhantomData, mem::MaybeUninit}; + +pub struct StoreConfig { + s_insn: SInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + imm: WitIn, + prev_memory_value: UInt, + + memory_addr: MemAddr, + word_change: Option>, +} + +pub struct StoreInstruction(PhantomData<(E, I)>); + +pub struct SWOp; + +impl RIVInstruction for SWOp { + const INST_KIND: InsnKind = InsnKind::SW; +} + +// this is actually used in test +#[allow(dead_code)] +pub type SwInstruction = StoreInstruction; + +pub struct SHOp; + +impl RIVInstruction for SHOp { + const INST_KIND: InsnKind = InsnKind::SH; +} + +// this is actually used in test +#[allow(dead_code)] +pub type ShInstruction = StoreInstruction; + +pub struct SBOp; + +impl RIVInstruction for SBOp { + const INST_KIND: InsnKind = InsnKind::SB; +} + +// this is actually used in test +#[allow(dead_code)] +pub type SbInstruction = StoreInstruction; + +impl Instruction + for StoreInstruction +{ + type InstructionConfig = StoreConfig; + + fn 
name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let prev_memory_value = UInt::new(|| "prev_memory_value", circuit_builder)?; + let imm = circuit_builder.create_witin(|| "imm")?; // signed 12-bit value + + let memory_addr = match I::INST_KIND { + InsnKind::SW => MemAddr::construct_align4(circuit_builder), + InsnKind::SH => MemAddr::construct_align2(circuit_builder), + InsnKind::SB => MemAddr::construct_unaligned(circuit_builder), + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }?; + + if cfg!(feature = "forbid_overflow") { + const MAX_RAM_ADDR: u32 = u32::MAX - 0x7FF; // max positive imm is 0x7FF + const MIN_RAM_ADDR: u32 = 0x800; // min negative imm is -0x800 + assert!( + !CENO_PLATFORM.can_write(MAX_RAM_ADDR + 1) + && !CENO_PLATFORM.can_write(MIN_RAM_ADDR - 1) + ); + } + circuit_builder.require_equal( + || "memory_addr = rs1_read + imm", + memory_addr.expr_unaligned(), + rs1_read.value() + imm.expr(), + )?; + + let (new_memory_value, word_change) = match I::INST_KIND { + InsnKind::SW => (rs2_read.memory_expr(), None), + InsnKind::SH | InsnKind::SB => { + let change = MemWordChange::::construct_circuit( + circuit_builder, + &memory_addr, + &prev_memory_value, + &rs2_read, + )?; + (prev_memory_value.value() + change.value(), Some(change)) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + let s_insn = SInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + &imm.expr(), + rs1_read.register_expr(), + rs2_read.register_expr(), + memory_addr.expr_align4(), + prev_memory_value.memory_expr(), + new_memory_value, + )?; + + Ok(StoreConfig { + s_insn, + rs1_read, + rs2_read, + imm, + prev_memory_value, + memory_addr, + word_change, + }) + } 
+ + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let rs2 = Value::new_unchecked(step.rs2().unwrap().value); + let memory_op = step.memory_op().unwrap(); + let imm: E::BaseField = InsnRecord::imm_or_funct7_field(&step.insn()); + let prev_mem_value = Value::new(memory_op.value.before, lk_multiplicity); + + let addr = ByteAddr::from( + step.rs1() + .unwrap() + .value + .wrapping_add(step.insn().imm_or_funct7()), + ); + config + .s_insn + .assign_instance(instance, lk_multiplicity, step)?; + config.rs1_read.assign_value(instance, rs1); + config.rs2_read.assign_value(instance, rs2); + set_val!(instance, config.imm, imm); + config + .prev_memory_value + .assign_value(instance, prev_mem_value); + + config + .memory_addr + .assign_instance(instance, lk_multiplicity, addr.into())?; + if let Some(change) = config.word_change.as_ref() { + change.assign_instance::(instance, lk_multiplicity, step, addr.shift())?; + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs new file mode 100644 index 000000000..ad4ee4b3e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -0,0 +1,299 @@ +use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + memory::{ + LwInstruction, SbInstruction, ShInstruction, SwInstruction, + load::{ + LbInstruction, LbOp, LbuInstruction, LbuOp, LhInstruction, LhOp, + LhuInstruction, LhuOp, LwOp, + }, + store::{SBOp, SHOp, SWOp}, + }, + }, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, +}; +use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, Word, WriteOp, encode_rv32}; +use ff_ext::ExtensionField; +use goldilocks::GoldilocksExt2; +use itertools::Itertools; +use 
multilinear_extensions::mle::IntoMLEs; +use std::hash::Hash; + +fn sb(prev: Word, rs2: Word, shift: u32) -> Word { + let shift = (shift * 8) as usize; + let mut data = prev; + data ^= data & (0xff << shift); + data |= (rs2 & 0xff) << shift; + + data +} + +fn sh(prev: Word, rs2: Word, shift: u32) -> Word { + assert_eq!(shift & 1, 0); + let shift = (shift * 8) as usize; + let mut data = prev; + + data ^= data & (0xffff << shift); + data |= (rs2 & 0xffff) << shift; + + data +} + +fn sw(_prev: Word, rs2: Word) -> Word { + rs2 +} + +fn signed_extend(val: u32, n_bits: u32) -> u32 { + match n_bits { + 8 => (val as i8) as u32, + 16 => (val as i16) as u32, + _ => unreachable!("unsupported n_bits = {}", n_bits), + } +} + +fn load(mem_value: Word, insn: InsnKind, shift: u32) -> Word { + let val = mem_value >> (8 * shift); + match insn { + InsnKind::LB => signed_extend(val & 0xff_u32, 8), + InsnKind::LBU => val & 0xff_u32, + InsnKind::LH => { + assert_eq!(shift & 0x01, 0); + signed_extend(val & 0xffff_u32, 16) + } + InsnKind::LHU => { + assert_eq!(shift & 0x01, 0); + val & 0xffff_u32 + } + InsnKind::LW => { + assert_eq!(shift & 0x03, 0); + mem_value + } + _ => unreachable!(), + } +} + +fn impl_opcode_store>(imm: u32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || Inst::name(), + |cb| { + let config = Inst::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(I::INST_KIND, 2, 3, 0, imm); + let prev_mem_value = 0x40302010; + let rs2_word = Word::from(0x12345678_u32); + let rs1_word = Word::from(0x4000000_u32); + let unaligned_addr = ByteAddr::from(rs1_word.wrapping_add(imm)); + let new_mem_value = match I::INST_KIND { + InsnKind::SB => sb(prev_mem_value, rs2_word, unaligned_addr.shift()), + InsnKind::SH => sh(prev_mem_value, rs2_word, unaligned_addr.shift()), + InsnKind::SW => sw(prev_mem_value, rs2_word), + x => unreachable!("{:?} is not 
store instruction", x), + }; + let (raw_witin, lkm) = Inst::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_s_instruction( + 12, + MOCK_PC_START, + insn_code, + rs1_word, + rs2_word, + WriteOp { + addr: unaligned_addr.waddr(), + value: Change { + before: prev_mem_value, + after: new_mem_value, + }, + previous_cycle: 4, + }, + 8, + ), + ]) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); +} + +fn impl_opcode_load>(imm: u32) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || Inst::name(), + |cb| { + let config = Inst::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(I::INST_KIND, 2, 3, 0, imm); + let mem_value = 0x40302010; + let rs1_word = Word::from(0x4000000_u32); + let prev_rd_word = Word::from(0x12345678_u32); + let unaligned_addr = ByteAddr::from(rs1_word.wrapping_add(imm)); + let new_rd_word = load(mem_value, I::INST_KIND, unaligned_addr.shift()); + let rd_change = Change { + before: prev_rd_word, + after: new_rd_word, + }; + let (raw_witin, lkm) = Inst::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_im_instruction( + 12, + MOCK_PC_START, + insn_code, + rs1_word, + rd_change, + ReadOp { + addr: unaligned_addr.waddr(), + value: mem_value, + previous_cycle: 4, + }, + 8, + ), + ]) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); +} + +fn impl_opcode_sb(imm: u32) { + impl_opcode_store::>(imm) +} + +fn impl_opcode_sh(imm: u32) { + assert_eq!(imm & 0x01, 0); + impl_opcode_store::>(imm) +} + +fn impl_opcode_sw(imm: u32) { + assert_eq!(imm & 0x03, 0); + impl_opcode_store::>(imm) +} 
+ +#[test] +fn test_sb() { + impl_opcode_sb(0); + impl_opcode_sb(5); + impl_opcode_sb(10); + impl_opcode_sb(15); + + let neg_one = u32::MAX; + for i in 0..4 { + impl_opcode_sb(neg_one - i); + } +} + +#[test] +fn test_sh() { + impl_opcode_sh(0); + impl_opcode_sh(2); + + let neg_two = u32::MAX - 1; + for i in [0, 2] { + impl_opcode_sh(neg_two - i) + } +} + +#[test] +fn test_sw() { + impl_opcode_sw(0); + impl_opcode_sw(4); + + let neg_four = u32::MAX - 3; + impl_opcode_sw(neg_four); +} + +#[test] +fn test_lb() { + impl_opcode_load::>(0); + impl_opcode_load::>(1); + impl_opcode_load::>(2); + impl_opcode_load::>(3); + + let neg_one = u32::MAX; + // imm = -1, -2, -3 + for i in 0..3 { + impl_opcode_load::>(neg_one - i); + } +} + +#[test] +fn test_lbu() { + impl_opcode_load::>(0); + impl_opcode_load::>(1); + impl_opcode_load::>(2); + impl_opcode_load::>(3); + + let neg_one = u32::MAX; + // imm = -1, -2, -3 + for i in 0..3 { + impl_opcode_load::>(neg_one - i); + } +} + +#[test] +fn test_lh() { + impl_opcode_load::>(0); + impl_opcode_load::>(2); + impl_opcode_load::>(4); + + let neg_two = u32::MAX - 1; + // imm = -2, -4 + for i in [0, 2] { + impl_opcode_load::>(neg_two - i); + } +} + +#[test] +fn test_lhu() { + impl_opcode_load::>(0); + impl_opcode_load::>(2); + impl_opcode_load::>(4); + + let neg_two = u32::MAX - 1; + // imm = -2, -4 + for i in [0, 2] { + impl_opcode_load::>(neg_two - i); + } +} + +#[test] +fn test_lw() { + impl_opcode_load::>(0); + impl_opcode_load::>(4); + impl_opcode_load::>(u32::MAX - 3); // imm = -4 +} diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 79f60a6a5..b852d9e2f 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -117,7 +117,7 @@ impl Instruction for MulhInstruction().as_u32()), @@ -182,7 +183,9 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } } diff --git 
a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs new file mode 100644 index 000000000..c50b9218d --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -0,0 +1,168 @@ +use crate::{ + error::ZKVMError, + instructions::Instruction, + structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, + tables::{ + AndTableCircuit, LtuTableCircuit, MemFinalRecord, MemInitRecord, MemTableCircuit, + RegTableCircuit, TableCircuit, U14TableCircuit, U16TableCircuit, + }, +}; +use ceno_emul::{CENO_PLATFORM, InsnKind, StepRecord}; +use ff_ext::ExtensionField; + +use super::{ + arith::AddInstruction, + branch::BltuInstruction, + ecall::HaltInstruction, + jump::{JalInstruction, LuiInstruction}, + memory::LwInstruction, +}; + +pub struct Rv32imConfig { + // Opcodes. + pub add_config: as Instruction>::InstructionConfig, + pub bltu_config: >::InstructionConfig, + pub jal_config: as Instruction>::InstructionConfig, + pub halt_config: as Instruction>::InstructionConfig, + pub lui_config: as Instruction>::InstructionConfig, + pub lw_config: as Instruction>::InstructionConfig, + + // Tables. + pub u16_range_config: as TableCircuit>::TableConfig, + pub u14_range_config: as TableCircuit>::TableConfig, + pub and_config: as TableCircuit>::TableConfig, + pub ltu_config: as TableCircuit>::TableConfig, + + // RW tables. 
+ pub reg_config: as TableCircuit>::TableConfig, + pub mem_config: as TableCircuit>::TableConfig, +} + +impl Rv32imConfig { + pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { + // opcode circuits + let add_config = cs.register_opcode_circuit::>(); + let bltu_config = cs.register_opcode_circuit::(); + let jal_config = cs.register_opcode_circuit::>(); + let halt_config = cs.register_opcode_circuit::>(); + let lui_config = cs.register_opcode_circuit::>(); + let lw_config = cs.register_opcode_circuit::>(); + + // tables + let u16_range_config = cs.register_table_circuit::>(); + let u14_range_config = cs.register_table_circuit::>(); + let and_config = cs.register_table_circuit::>(); + let ltu_config = cs.register_table_circuit::>(); + + // RW tables + let reg_config = cs.register_table_circuit::>(); + let mem_config = cs.register_table_circuit::>(); + + Self { + add_config, + bltu_config, + jal_config, + halt_config, + lui_config, + lw_config, + u16_range_config, + u14_range_config, + and_config, + ltu_config, + + reg_config, + mem_config, + } + } + + pub fn generate_fixed_traces( + &self, + cs: &ZKVMConstraintSystem, + fixed: &mut ZKVMFixedTraces, + reg_init: &[MemInitRecord], + mem_init: &[MemInitRecord], + ) { + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + + fixed.register_table_circuit::>(cs, self.u16_range_config.clone(), &()); + fixed.register_table_circuit::>(cs, self.u14_range_config.clone(), &()); + fixed.register_table_circuit::>(cs, self.and_config.clone(), &()); + fixed.register_table_circuit::>(cs, self.ltu_config.clone(), &()); + + fixed.register_table_circuit::>(cs, self.reg_config.clone(), reg_init); + fixed.register_table_circuit::>(cs, self.mem_config.clone(), mem_init); + } + + pub fn assign_opcode_circuit( + &self, + cs: 
&ZKVMConstraintSystem, + witness: &mut ZKVMWitnesses, + steps: Vec, + ) -> Result<(), ZKVMError> { + use InsnKind::*; + + let mut add_records = Vec::new(); + let mut bltu_records = Vec::new(); + let mut jal_records = Vec::new(); + let mut halt_records = Vec::new(); + let mut lui_records = Vec::new(); + let mut lw_records = Vec::new(); + steps + .into_iter() + .for_each(|record| match record.insn().codes().kind { + ADD => add_records.push(record), + BLTU => bltu_records.push(record), + JAL => jal_records.push(record), + EANY if record.rs1().unwrap().value == CENO_PLATFORM.ecall_halt() => { + halt_records.push(record); + } + LUI => lui_records.push(record), + LW => lw_records.push(record), + i => unimplemented!("instruction {i:?}"), + }); + + tracing::info!( + "tracer generated {} ADD records, {} BLTU records, {} JAL records", + add_records.len(), + bltu_records.len(), + jal_records.len(), + ); + assert_eq!(halt_records.len(), 1); + + witness.assign_opcode_circuit::>(cs, &self.add_config, add_records)?; + witness.assign_opcode_circuit::(cs, &self.bltu_config, bltu_records)?; + witness.assign_opcode_circuit::>(cs, &self.jal_config, jal_records)?; + witness.assign_opcode_circuit::>(cs, &self.halt_config, halt_records)?; + witness.assign_opcode_circuit::>(cs, &self.lui_config, lui_records)?; + witness.assign_opcode_circuit::>(cs, &self.lw_config, lw_records)?; + Ok(()) + } + + pub fn assign_table_circuit( + &self, + cs: &ZKVMConstraintSystem, + witness: &mut ZKVMWitnesses, + reg_final: &[MemFinalRecord], + mem_final: &[MemFinalRecord], + ) -> Result<(), ZKVMError> { + witness.assign_table_circuit::>(cs, &self.u16_range_config, &())?; + witness.assign_table_circuit::>(cs, &self.u14_range_config, &())?; + witness.assign_table_circuit::>(cs, &self.and_config, &())?; + witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; + + // assign register finalization. 
+ witness + .assign_table_circuit::>(cs, &self.reg_config, reg_final) + .unwrap(); + // assign memory finalization. + witness + .assign_table_circuit::>(cs, &self.mem_config, mem_final) + .unwrap(); + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index a0166d929..702a3d5ea 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -1,5 +1,5 @@ use crate::{ - chip_handler::{MemoryExpr, RegisterExpr}, + chip_handler::{AddressExpr, MemoryExpr, RegisterExpr}, circuit_builder::CircuitBuilder, error::ZKVMError, expression::{Expression, ToExpr}, @@ -19,18 +19,20 @@ pub struct SInstructionConfig { vm_state: StateInOut, rs1: ReadRS1, rs2: ReadRS2, - mem_write: WriteMEM, + mem_write: WriteMEM, } impl SInstructionConfig { + #[allow(clippy::too_many_arguments)] pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, imm: &Expression, rs1_read: RegisterExpr, rs2_read: RegisterExpr, - memory_addr: MemoryExpr, - memory_value: MemoryExpr, + memory_addr: AddressExpr, + prev_memory_value: MemoryExpr, + new_memory_value: MemoryExpr, ) -> Result { // State in and out let vm_state = StateInOut::construct_circuit(circuit_builder, false)?; @@ -51,8 +53,13 @@ impl SInstructionConfig { ))?; // Memory - let mem_write = - WriteMEM::construct_circuit(circuit_builder, memory_addr, memory_value, vm_state.ts)?; + let mem_write = WriteMEM::construct_circuit( + circuit_builder, + memory_addr, + prev_memory_value, + new_memory_value, + vm_state.ts, + )?; Ok(SInstructionConfig { vm_state, @@ -72,7 +79,7 @@ impl SInstructionConfig { self.rs1.assign_instance(instance, lk_multiplicity, step)?; self.rs2.assign_instance(instance, lk_multiplicity, step)?; self.mem_write - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance::(instance, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git 
a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 747e3f1e1..189811dd2 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -104,7 +104,7 @@ impl Instruction for ShiftLogicalInstru circuit_builder.require_equal( || "rs2 == rs2_high * 2^5 + rs2_low5", rs2_read.value(), - rs2_high.value() * (1 << 5).into() + rs2_low5.expr(), + rs2_high.value() * (1 << 5) + rs2_low5.expr(), )?; Ok(ShiftConfig { @@ -176,17 +176,22 @@ impl Instruction for ShiftLogicalInstru .r_insn .assign_instance(instance, lk_multiplicity, step)?; config.rs2_read.assign_value(instance, rs2_read); + set_val!(instance, config.rs2_low5, rs2_low5); + lk_multiplicity.assert_ux::<5>(rs2_low5); + config.rs2_high.assign_value(instance, rs2_high); config.pow2_rs2_low5.assign_value(instance, pow2_rs2_low5); + lk_multiplicity.lookup_pow2(rs2_low5); + Ok(()) } } #[cfg(test)] mod tests { - use ceno_emul::{Change, InsnKind, StepRecord}; + use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; @@ -198,7 +203,7 @@ mod tests { Instruction, riscv::{RIVInstruction, constants::UInt}, }, - scheme::mock_prover::{MOCK_PC_SLL, MOCK_PC_SRL, MOCK_PROGRAM, MockProver}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, }; use super::{ShiftLogicalInstruction, SllOp, SrlOp}; @@ -235,9 +240,17 @@ mod tests { let mut cb = CircuitBuilder::new(&mut cs); let shift = rs2_read & 0b11111; - let (prefix, mock_pc, mock_program_op, rd_written) = match I::INST_KIND { - InsnKind::SLL => ("SLL", MOCK_PC_SLL, MOCK_PROGRAM[19], rs1_read << shift), - InsnKind::SRL => ("SRL", MOCK_PC_SRL, MOCK_PROGRAM[20], rs1_read >> shift), + let (prefix, insn_code, rd_written) = match I::INST_KIND { + InsnKind::SLL => ( + "SLL", + encode_rv32(InsnKind::SLL, 2, 3, 4, 0), + rs1_read << shift, + ), + InsnKind::SRL => ( + "SRL", + encode_rv32(InsnKind::SRL, 2, 3, 
4, 0), + rs1_read >> shift, + ), _ => unreachable!(), }; @@ -266,13 +279,13 @@ mod tests { ) .unwrap(); - let (raw_witin, _) = ShiftLogicalInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( &config, cb.cs.num_witin as usize, vec![StepRecord::new_r_instruction( 3, - mock_pc, - mock_program_op, + MOCK_PC_START, + insn_code, rs1_read, rs2_read, Change::new(0, rd_written), @@ -289,7 +302,9 @@ mod tests { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 4b4b2b8e2..d1af0871b 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -3,35 +3,51 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - gadgets::DivConfig, + expression::{Expression, ToExpr, WitIn}, + gadgets::{AssertLTConfig, IsLtConfig}, instructions::{ Instruction, riscv::{constants::UInt, i_insn::IInstructionConfig}, }, + set_val, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::{marker::PhantomData, mem::MaybeUninit}; -pub struct InstructionConfig { +pub struct ShiftImmConfig { i_insn: IInstructionConfig, - imm: UInt, + imm: WitIn, + rs1_read: UInt, rd_written: UInt, - remainder: UInt, - div_config: DivConfig, + outflow: WitIn, + assert_lt_config: AssertLTConfig, + + // SRAI + is_lt_config: Option, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); +pub struct SlliOp; +impl RIVInstruction for SlliOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLLI; +} + +pub struct SraiOp; +impl RIVInstruction for SraiOp { + const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI; +} + pub struct SrliOp; impl RIVInstruction for SrliOp { - const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRLI; + const INST_KIND: 
ceno_emul::InsnKind = InsnKind::SRLI; } impl Instruction for ShiftImmInstruction { - type InstructionConfig = InstructionConfig; + type InstructionConfig = ShiftImmConfig; fn name() -> String { format!("{:?}", I::INST_KIND) @@ -40,35 +56,77 @@ impl Instruction for ShiftImmInstructio fn construct_circuit( circuit_builder: &mut CircuitBuilder, ) -> Result { - let mut imm = UInt::new(|| "imm", circuit_builder)?; - let mut rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - // Note: `imm` is set to 2**imm (upto 32 bit) just for SRLI for efficient verification - // Goal is to constrain: - // rs1 == rd_written * imm + remainder - let remainder = UInt::new(|| "remainder", circuit_builder)?; - let div_config = DivConfig::construct_circuit( + // Note: `imm` wtns is set to 2**imm (upto 32 bit) just for efficient verification. + let imm = circuit_builder.create_witin(|| "imm")?; + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; + let rd_written = UInt::new(|| "rd_written", circuit_builder)?; + + let outflow = circuit_builder.create_witin(|| "outflow")?; + let assert_lt_config = AssertLTConfig::construct_circuit( circuit_builder, - || "srli_div", - &mut imm, - &mut rd_written, - &remainder, + || "outflow < imm", + outflow.expr(), + imm.expr(), + 2, )?; + let two_pow_total_bits: Expression<_> = (1u64 << UInt::::TOTAL_BITS).into(); + + let is_lt_config = match I::INST_KIND { + InsnKind::SLLI => { + circuit_builder.require_equal( + || "shift check", + rs1_read.value() * imm.expr(), // inflow is zero for this case + outflow.expr() * two_pow_total_bits + rd_written.value(), + )?; + None + } + InsnKind::SRAI | InsnKind::SRLI => { + let (inflow, is_lt_config) = match I::INST_KIND { + InsnKind::SRAI => { + let max_signed_limb_expr: Expression<_> = + ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + circuit_builder, + || "lhs_msb", + max_signed_limb_expr, + rs1_read.limbs.iter().last().unwrap().expr(), // 
msb limb + 1, + )?; + let msb_expr: Expression = is_rs1_neg.is_lt.expr(); + let ones = imm.expr() - Expression::ONE; + (msb_expr * ones, Some(is_rs1_neg)) + } + InsnKind::SRLI => (Expression::ZERO, None), + _ => unreachable!(), + }; + circuit_builder.require_equal( + || "shift check", + rd_written.value() * imm.expr() + outflow.expr(), + inflow * two_pow_total_bits + rs1_read.value(), + )?; + is_lt_config + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.value(), - div_config.dividend.register_expr(), + &imm.expr(), + rs1_read.register_expr(), rd_written.register_expr(), + false, )?; - Ok(InstructionConfig { + Ok(ShiftImmConfig { i_insn, imm, + rs1_read, rd_written, - remainder, - div_config, + outflow, + assert_lt_config, + is_lt_config, }) } @@ -78,26 +136,36 @@ impl Instruction for ShiftImmInstructio lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let imm = step.insn().imm_or_funct7(); + let rs1_read = Value::new_unchecked(step.rs1().unwrap().value); let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - let (remainder, imm) = { - let rs1_read = step.rs1().unwrap().value; - let imm = step.insn().imm_or_funct7(); - ( - Value::new(rs1_read % imm, lk_multiplicity), - Value::new(imm, lk_multiplicity), - ) - }; - config.div_config.assign_instance( - instance, - lk_multiplicity, - &imm, - &rd_written, - &remainder, - )?; - config.imm.assign_value(instance, imm); + set_val!(instance, config.imm, imm as u64); + config.rs1_read.assign_value(instance, rs1_read.clone()); config.rd_written.assign_value(instance, rd_written); - config.remainder.assign_value(instance, remainder); + + let outflow = match I::INST_KIND { + InsnKind::SLLI => (rs1_read.as_u64() * imm as u64) >> UInt::::TOTAL_BITS, + InsnKind::SRAI | InsnKind::SRLI => { + if I::INST_KIND == InsnKind::SRAI { + let 
max_signed_limb_expr = (1 << (UInt::::LIMB_BITS - 1)) - 1; + config.is_lt_config.as_ref().unwrap().assign_instance( + instance, + lk_multiplicity, + max_signed_limb_expr, + rs1_read.as_u64() >> UInt::::LIMB_BITS, + )?; + } + + rs1_read.as_u64() & (imm as u64 - 1) + } + _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), + }; + + set_val!(instance, config.outflow, outflow); + config + .assert_lt_config + .assign_instance(instance, lk_multiplicity, outflow, imm as u64)?; config .i_insn @@ -109,39 +177,96 @@ impl Instruction for ShiftImmInstructio #[cfg(test)] mod test { - use ceno_emul::{ByteAddr, Change, StepRecord}; + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; use itertools::Itertools; use multilinear_extensions::mle::IntoMLEs; + use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::constants::UInt}, - scheme::mock_prover::{MOCK_PC_SRLI, MOCK_PC_SRLI_31, MOCK_PROGRAM, MockProver}, + instructions::{ + Instruction, + riscv::{RIVInstruction, constants::UInt}, + }, + scheme::mock_prover::{MOCK_PC_START, MockProver}, }; - use super::{ShiftImmInstruction, SrliOp}; + #[test] + fn test_opcode_slli() { + // imm = 3 + verify::("32 << 3", 32, 3, 32 << 3); + verify::("33 << 3", 33, 3, 33 << 3); + // imm = 31 + verify::("32 << 31", 32, 31, 32 << 31); + verify::("33 << 31", 33, 31, 33 << 31); + } + + #[test] + fn test_opcode_srai() { + // positive rs1 + // imm = 3 + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); + // imm = 31 + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); + + // negative rs1 + // imm = 3 + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32 >> 3) as u32); + verify::("-33 >> 3", (-33_i32) as u32, 3, (-33_i32 >> 3) as u32); + // imm = 31 + verify::("-32 >> 31", (-32_i32) as u32, 31, (-32_i32 >> 31) as u32); + 
verify::("-33 >> 31", (-33_i32) as u32, 31, (-33_i32 >> 31) as u32); + } #[test] fn test_opcode_srli() { // imm = 3 - verify_srli(MOCK_PC_SRLI, MOCK_PROGRAM[10], 32, 32 >> 3); - verify_srli(MOCK_PC_SRLI, MOCK_PROGRAM[10], 33, 33 >> 3); + verify::("32 >> 3", 32, 3, 32 >> 3); + verify::("33 >> 3", 33, 3, 33 >> 3); // imm = 31 - verify_srli(MOCK_PC_SRLI_31, MOCK_PROGRAM[11], 32, 32 >> 31); - verify_srli(MOCK_PC_SRLI_31, MOCK_PROGRAM[11], 33, 33 >> 31); + verify::("32 >> 31", 32, 31, 32 >> 31); + verify::("33 >> 31", 33, 31, 33 >> 31); + // rs1 top bit is 1 + verify::("-32 >> 3", (-32_i32) as u32, 3, (-32_i32) as u32 >> 3); } - fn verify_srli(pc: ByteAddr, program: u32, rs1_read: u32, expected_rd_written: u32) { + fn verify( + name: &'static str, + rs1_read: u32, + imm: u32, + expected_rd_written: u32, + ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + + let (prefix, insn_code, rd_written) = match I::INST_KIND { + InsnKind::SLLI => ( + "SLLI", + encode_rv32(InsnKind::SLLI, 2, 0, 4, imm), + rs1_read << imm, + ), + InsnKind::SRAI => ( + "SRAI", + encode_rv32(InsnKind::SRAI, 2, 0, 4, imm), + (rs1_read as i32 >> imm as i32) as u32, + ), + InsnKind::SRLI => ( + "SRLI", + encode_rv32(InsnKind::SRLI, 2, 0, 4, imm), + rs1_read >> imm, + ), + _ => unreachable!(), + }; + let config = cb .namespace( - || "srli", + || format!("{prefix}_({name})"), |cb| { - let config = - ShiftImmInstruction::::construct_circuit(cb); + let config = ShiftImmInstruction::::construct_circuit(cb); Ok(config) }, ) @@ -151,7 +276,7 @@ mod test { config .rd_written .require_equal( - || "assert_rd_written", + || format!("{prefix}_({name})_assert_rd_written"), &mut cb, &UInt::from_const_unchecked( Value::new_unchecked(expected_rd_written) @@ -161,15 +286,15 @@ mod test { ) .unwrap(); - let (raw_witin, _) = ShiftImmInstruction::::assign_instances( + let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, cb.cs.num_witin as usize, 
vec![StepRecord::new_i_instruction( 3, - pc, - program, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, rs1_read, - Change::new(0, expected_rd_written), + Change::new(0, rd_written), 0, )], ) @@ -183,7 +308,9 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), + &[insn_code], None, + Some(lkm), ); } } diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs new file mode 100644 index 000000000..fb138f2aa --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -0,0 +1,187 @@ +use std::marker::PhantomData; + +use ceno_emul::{InsnKind, SWord, StepRecord}; +use ff_ext::ExtensionField; + +use super::{constants::UInt, r_insn::RInstructionConfig}; +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, gadgets::SignedLtConfig, + instructions::Instruction, uint::Value, witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +/// This config handles R-Instructions that represent registers values as 2 * u16. +#[derive(Debug)] +pub struct SltConfig { + r_insn: RInstructionConfig, + + rs1_read: UInt, + rs2_read: UInt, + #[allow(dead_code)] + rd_written: UInt, + + signed_lt: SignedLtConfig, +} + +pub struct SltInstruction(PhantomData); + +// TODO combine with SLTU +impl Instruction for SltInstruction { + type InstructionConfig = SltConfig; + + fn name() -> String { + format!("{:?}", InsnKind::SLT) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + // If rs1_read < rs2_read, rd_written = 1. 
Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let rs2_read = UInt::new_unchecked(|| "rs2_read", cb)?; + + let lt = SignedLtConfig::construct_circuit(cb, || "rs1 < rs2", &rs1_read, &rs2_read)?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); + + let r_insn = RInstructionConfig::::construct_circuit( + cb, + InsnKind::SLT, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; + + Ok(SltConfig { + r_insn, + rs1_read, + rs2_read, + rd_written, + signed_lt: lt, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit<::BaseField>], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.r_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let rs2 = step.rs2().unwrap().value; + + let rs1_read = Value::new_unchecked(rs1); + let rs2_read = Value::new_unchecked(rs2); + config + .rs1_read + .assign_limbs(instance, rs1_read.as_u16_limbs()); + config + .rs2_read + .assign_limbs(instance, rs2_read.as_u16_limbs()); + config + .signed_lt + .assign_instance::(instance, lkm, rs1 as SWord, rs2 as SWord)?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ceno_emul::{Change, StepRecord, Word, encode_rv32}; + use goldilocks::GoldilocksExt2; + + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + use rand::Rng; + + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + }; + + fn verify(name: &'static str, rs1: i32, rs2: i32, rd: Word) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || format!("SLT/{name}"), + |cb| { + let config = SltInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 
0); + let (raw_witin, lkm) = + SltInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_r_instruction( + 3, + MOCK_PC_START, + insn_code, + rs1 as Word, + rs2 as Word, + Change::new(0, rd), + 0, + ), + ]) + .unwrap(); + + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + config + .rd_written + .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); + } + + #[test] + fn test_slt_true() { + verify("lt = true, 0 < 1", 0, 1, 1); + verify("lt = true, 1 < 2", 1, 2, 1); + verify("lt = true, -1 < 0", -1, 0, 1); + verify("lt = true, -1 < 1", -1, 1, 1); + verify("lt = true, -2 < -1", -2, -1, 1); + verify("lt = true, large number", i32::MIN, i32::MAX, 1); + } + + #[test] + fn test_slt_false() { + verify("lt = false, 1 < 0", 1, 0, 0); + verify("lt = false, 2 < 1", 2, 1, 0); + verify("lt = false, 0 < -1", 0, -1, 0); + verify("lt = false, 1 < -1", 1, -1, 0); + verify("lt = false, -1 < -2", -1, -2, 0); + verify("lt = false, 0 == 0", 0, 0, 0); + verify("lt = false, 1 == 1", 1, 1, 0); + verify("lt = false, -1 == -1", -1, -1, 0); + // This case causes subtract overflow in `assign_instance_signed` + verify("lt = false, large number", i32::MAX, i32::MIN, 0); + } + + #[test] + fn test_slt_random() { + let mut rng = rand::thread_rng(); + let a: i32 = rng.gen(); + let b: i32 = rng.gen(); + println!("random: {} = b) as u32); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs new file mode 100644 index 000000000..a0db3fbc0 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -0,0 +1,220 @@ +use std::marker::PhantomData; + +use ceno_emul::{InsnKind, SWord, StepRecord, Word}; +use ff_ext::ExtensionField; + +use super::{ + 
constants::{UINT_LIMBS, UInt}, + i_insn::IInstructionConfig, +}; +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr, WitIn}, + gadgets::IsLtConfig, + instructions::Instruction, + set_val, + tables::InsnRecord, + uint::Value, + witness::LkMultiplicity, +}; +use core::mem::MaybeUninit; + +#[derive(Debug)] +pub struct InstructionConfig { + i_insn: IInstructionConfig, + + rs1_read: UInt, + imm: WitIn, + #[allow(dead_code)] + rd_written: UInt, + + is_rs1_neg: IsLtConfig, + lt: IsLtConfig, +} + +pub struct SltiInstruction(PhantomData); + +impl Instruction for SltiInstruction { + type InstructionConfig = InstructionConfig; + + fn name() -> String { + format!("{:?}", InsnKind::SLTI) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + // If rs1_read < imm, rd_written = 1. Otherwise rd_written = 0 + let rs1_read = UInt::new_unchecked(|| "rs1_read", cb)?; + let imm = cb.create_witin(|| "imm")?; + + let max_signed_limb_expr: Expression<_> = ((1 << (UInt::::LIMB_BITS - 1)) - 1).into(); + let is_rs1_neg = IsLtConfig::construct_circuit( + cb, + || "lhs_msb", + max_signed_limb_expr, + rs1_read.limbs.iter().last().unwrap().expr(), // msb limb + 1, + )?; + + let lt = IsLtConfig::construct_circuit( + cb, + || "rs1 < imm", + rs1_read.to_field_expr(is_rs1_neg.expr()), + imm.expr(), + UINT_LIMBS, + )?; + let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); + + let i_insn = IInstructionConfig::::construct_circuit( + cb, + InsnKind::SLTI, + &imm.expr(), + rs1_read.register_expr(), + rd_written.register_expr(), + false, + )?; + + Ok(InstructionConfig { + i_insn, + rs1_read, + imm, + rd_written, + is_rs1_neg, + lt, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [MaybeUninit], + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config.i_insn.assign_instance(instance, lkm, step)?; + + let rs1 = step.rs1().unwrap().value; + let max_signed_limb = 
(1u64 << (UInt::::LIMB_BITS - 1)) - 1; + let rs1_value = Value::new_unchecked(rs1 as Word); + config + .rs1_read + .assign_value(instance, Value::new_unchecked(rs1)); + config.is_rs1_neg.assign_instance( + instance, + lkm, + max_signed_limb, + *rs1_value.limbs.last().unwrap() as u64, + )?; + + let imm = step.insn().imm_or_funct7(); + let imm_field = InsnRecord::imm_or_funct7_field::(&step.insn()); + set_val!(instance, config.imm, imm_field); + + config + .lt + .assign_instance_signed(instance, lkm, rs1 as SWord, imm as SWord)?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use ceno_emul::{Change, PC_STEP_SIZE, StepRecord, Word, encode_rv32}; + use goldilocks::GoldilocksExt2; + + use itertools::Itertools; + use multilinear_extensions::mle::IntoMLEs; + use rand::Rng; + + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::test_utils::imm_i}, + scheme::mock_prover::{MOCK_PC_START, MockProver}, + }; + + fn verify(name: &'static str, rs1: i32, imm: i32, rd: Word) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || format!("SLTI/{name}"), + |cb| { + let config = SltiInstruction::construct_circuit(cb); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm_i(imm)); + let (raw_witin, lkm) = + SltiInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ + StepRecord::new_i_instruction( + 3, + Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), + insn_code, + rs1 as Word, + Change::new(0, rd), + 0, + ), + ]) + .unwrap(); + + let expected_rd_written = + UInt::from_const_unchecked(Value::new_unchecked(rd).as_u16_limbs().to_vec()); + config + .rd_written + .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) + .unwrap(); + + MockProver::assert_satisfied( + &cb, + &raw_witin + .de_interleaving() + .into_mles() + .into_iter() + .map(|v| 
v.into()) + .collect_vec(), + &[insn_code], + None, + Some(lkm), + ); + } + + #[test] + fn test_slti_true() { + verify("lt = true, 0 < 1", 0, 1, 1); + verify("lt = true, 1 < 2", 1, 2, 1); + verify("lt = true, -1 < 0", -1, 0, 1); + verify("lt = true, -1 < 1", -1, 1, 1); + verify("lt = true, -2 < -1", -2, -1, 1); + // -2048 <= imm <= 2047 + verify("lt = true, imm upper bondary", i32::MIN, 2047, 1); + verify("lt = true, imm lower bondary", i32::MIN, -2048, 1); + } + + #[test] + fn test_slti_false() { + verify("lt = false, 1 < 0", 1, 0, 0); + verify("lt = false, 2 < 1", 2, 1, 0); + verify("lt = false, 0 < -1", 0, -1, 0); + verify("lt = false, 1 < -1", 1, -1, 0); + verify("lt = false, -1 < -2", -1, -2, 0); + verify("lt = false, 0 == 0", 0, 0, 0); + verify("lt = false, 1 == 1", 1, 1, 0); + verify("lt = false, -1 == -1", -1, -1, 0); + // -2048 <= imm <= 2047 + verify("lt = false, imm upper bondary", i32::MAX, 2047, 0); + verify("lt = false, imm lower bondary", i32::MAX, -2048, 0); + } + + #[test] + fn test_slti_random() { + let mut rng = rand::thread_rng(); + let a: i32 = rng.gen(); + let b: i32 = rng.gen::() % 2048; + println!("random: {} = ArithInstruction; +// TODO combine with SLT impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; @@ -56,7 +57,7 @@ impl Instruction for ArithInstruction::construct_circuit( circuit_builder, @@ -104,7 +105,7 @@ impl Instruction for ArithInstruction u32 { + // imm is 13 bits in B-type + imm_with_max_valid_bits(imm, 13) +} + +pub fn imm_i(imm: i32) -> u32 { + // imm is 12 bits in I-type + imm_with_max_valid_bits(imm, 12) +} + +pub fn imm_j(imm: i32) -> u32 { + // imm is 21 bits in J-type + imm_with_max_valid_bits(imm, 21) +} + +fn imm_with_max_valid_bits(imm: i32, bits: u32) -> u32 { + imm as u32 & !(u32::MAX << bits) +} + +pub fn imm_u(imm: u32) -> u32 { + // valid imm is imm[12:31] in U-type + imm << 12 +} diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index af6fa0935..b17573384 100644 --- 
a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -24,6 +24,9 @@ impl ZKVMConstraintSystem { assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } + vm_pk.initial_global_state_expr = self.initial_global_state_expr; + vm_pk.finalize_global_state_expr = self.finalize_global_state_expr; + Ok(vm_pk) } } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 5af066611..cb5978754 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -12,6 +12,7 @@ pub mod circuit_builder; pub mod expression; pub mod gadgets; mod keygen; +pub mod state; pub mod structs; mod uint; mod utils; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 24b6aebd3..3e07f64d8 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,5 +1,6 @@ use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; +use serde::{Deserialize, Serialize}; use std::{collections::BTreeMap, fmt::Debug}; use sumcheck::structs::IOPProverMessage; @@ -43,21 +44,18 @@ pub struct ZKVMOpcodeProof pub wits_in_evals: Vec, } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct ZKVMTableProof> { - pub num_instances: usize, - // logup sum at layer 1 - pub lk_p1_out_eval: E, - pub lk_p2_out_eval: E, - pub lk_q1_out_eval: E, - pub lk_q2_out_eval: E, + // tower evaluation at layer 1 + pub r_out_evals: Vec<[E; 2]>, + pub w_out_evals: Vec<[E; 2]>, + pub lk_out_evals: Vec<[E; 4]>, - pub tower_proof: TowerProofs, + pub same_r_sumcheck_proofs: Option>>, + pub rw_in_evals: Vec, + pub lk_in_evals: Vec, - // select layer sumcheck proof - pub sel_sumcheck_proofs: Vec>, - pub lk_d_in_evals: Vec, - pub lk_n_in_evals: Vec, + pub tower_proof: TowerProofs, pub fixed_in_evals: Vec, pub fixed_opening_proof: PCS::Proof, @@ -69,18 +67,30 @@ pub struct ZKVMTableProof> #[derive(Default, Clone, Debug)] pub struct PublicValues { exit_code: T, + init_pc: T, + init_cycle: T, end_pc: T, + end_cycle: T, } impl PublicValues { - pub fn new(exit_code: u32, end_pc: u32) 
-> Self { - Self { exit_code, end_pc } + pub fn new(exit_code: u32, init_pc: u32, init_cycle: u32, end_pc: u32, end_cycle: u32) -> Self { + Self { + exit_code, + init_pc, + init_cycle, + end_pc, + end_cycle, + } } pub fn to_vec(&self) -> Vec { vec![ E::BaseField::from((self.exit_code & 0xffff) as u64), E::BaseField::from(((self.exit_code >> 16) & 0xffff) as u64), + E::BaseField::from(self.init_pc as u64), + E::BaseField::from(self.init_cycle as u64), E::BaseField::from(self.end_pc as u64), + E::BaseField::from(self.end_cycle as u64), ] } } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index ab3c13dca..6c325573a 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1,19 +1,23 @@ use super::utils::{eval_by_expr, wit_infer_by_expr}; use crate::{ + ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, expression::{Expression, fmt}, scheme::utils::eval_by_expr_with_fixed, tables::{ AndTable, LtuTable, OpsTable, OrTable, PowTable, ProgramTableCircuit, RangeTable, - TableCircuit, U5Table, U8Table, U16Table, XorTable, + TableCircuit, U5Table, U8Table, U14Table, U16Table, XorTable, }, + witness::LkMultiplicity, }; use ark_std::test_rng; use base64::{Engine, engine::general_purpose::STANDARD_NO_PAD}; use ceno_emul::{ByteAddr, CENO_PLATFORM}; +use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; -use itertools::Itertools; +use goldilocks::SmallField; +use itertools::{Itertools, izip}; use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use std::{ collections::HashSet, @@ -24,111 +28,25 @@ use std::{ ops::Neg, sync::OnceLock, }; +use strum::IntoEnumIterator; -pub const MOCK_RS1: u32 = 2; -pub const MOCK_RS2: u32 = 3; -pub const MOCK_RD: u32 = 4; -pub const MOCK_IMM_3: u32 = 3; -pub const MOCK_IMM_31: u32 = 31; -pub const MOCK_IMM_NEG3: u32 = 32 - 3; -/// The program baked in the MockProver. -/// TODO: Make this a parameter? 
-#[allow(clippy::identity_op)] -#[allow(clippy::unusual_byte_groupings)] -pub const MOCK_PROGRAM: &[u32] = &[ - // R-Type - // funct7 | rs2 | rs1 | funct3 | rd | opcode - // ----------------------------------------- - // add x4, x2, x3 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, - // sub x4, x2, x3 - 0x20 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, - // mul (0x01, 0x00, 0x33) - 0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x33, - // and x4, x2, x3 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b111 << 12 | MOCK_RD << 7 | 0x33, - // or x4, x2, x3 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b110 << 12 | MOCK_RD << 7 | 0x33, - // xor x4, x2, x3 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b100 << 12 | MOCK_RD << 7 | 0x33, - // B-Type - // beq x2, x3, 8 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b000 << 12 | 0x08 << 7 | 0x63, - // bne x2, x3, 8 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | 0x08 << 7 | 0x63, - // blt x2, x3, -8 - 0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_100 << 12 | 0b_1100_1 << 7 | 0x63, - // divu (0x01, 0x05, 0x33) - 0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33, - // srli x4, x2, 3 - 0x00 << 25 | MOCK_IMM_3 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x13, - // srli x4, x2, 31 - 0x00 << 25 | MOCK_IMM_31 << 20 | MOCK_RS1 << 15 | 0x05 << 12 | MOCK_RD << 7 | 0x13, - // sltu (0x00, 0x03, 0x33) - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b011 << 12 | MOCK_RD << 7 | 0x33, - // addi x4, x2, 3 - 0x00 << 25 | MOCK_IMM_3 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x13, - // addi x4, x2, -3 - 0b_1_111111 << 25 | MOCK_IMM_NEG3 << 20 | MOCK_RS1 << 15 | 0x00 << 12 | MOCK_RD << 7 | 0x13, - // bltu x2, x3, -8 - 0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_110 << 12 | 0b_1100_1 << 7 | 0x63, - // bgeu x2, x3, -8 - 0b_1_111111 << 25 | MOCK_RS2 << 
20 | MOCK_RS1 << 15 | 0b_111 << 12 | 0b_1100_1 << 7 | 0x63, - // bge x2, x3, -8 - 0b_1_111111 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b_101 << 12 | 0b_1100_1 << 7 | 0x63, - // mulhu (0x01, 0x00, 0x33) - 0x01 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0x3 << 12 | MOCK_RD << 7 | 0x33, - // sll x4, x2, x3 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b001 << 12 | MOCK_RD << 7 | 0x33, - // srl x4, x2, x3 - 0x00 << 25 | MOCK_RS2 << 20 | MOCK_RS1 << 15 | 0b101 << 12 | MOCK_RD << 7 | 0x33, - // jal x4, 0xffffe - 0b_1_1111111110_1_11111111 << 12 | MOCK_RD << 7 | 0x6f, - // lui x4, 0x90005 - 0x90005 << 12 | MOCK_RD << 7 | 0x37, - // auipc x4, 0x90005 - 0x90005 << 12 | MOCK_RD << 7 | 0x17, -]; -// Addresses of particular instructions in the mock program. -pub const MOCK_PC_ADD: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); -pub const MOCK_PC_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 4); -pub const MOCK_PC_MUL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 8); -pub const MOCK_PC_AND: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 12); -pub const MOCK_PC_OR: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 16); -pub const MOCK_PC_XOR: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 20); -pub const MOCK_PC_BEQ: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 24); -pub const MOCK_PC_BNE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 28); -pub const MOCK_PC_BLT: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 32); -pub const MOCK_PC_DIVU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 36); -pub const MOCK_PC_SRLI: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 40); -pub const MOCK_PC_SRLI_31: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 44); -pub const MOCK_PC_SLTU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 48); -pub const MOCK_PC_ADDI: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 52); -pub const MOCK_PC_ADDI_SUB: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 56); -pub const MOCK_PC_BLTU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 60); -pub const 
MOCK_PC_BGEU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 64); -pub const MOCK_PC_BGE: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 68); -pub const MOCK_PC_MULHU: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 72); -pub const MOCK_PC_SLL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 76); -pub const MOCK_PC_SRL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 80); -pub const MOCK_PC_JAL: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 84); -pub const MOCK_PC_LUI: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 88); -pub const MOCK_PC_AUIPC: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start() + 92); +const MOCK_PROGRAM_SIZE: usize = 32; +pub const MOCK_PC_START: ByteAddr = ByteAddr(CENO_PLATFORM.pc_start()); #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone)] pub(crate) enum MockProverError { AssertZeroError { expression: Expression, - evaluated: E, + evaluated: E::BaseField, name: String, inst_id: usize, }, AssertEqualError { left_expression: Expression, right_expression: Expression, - left: E, - right: E, + left: E::BaseField, + right: E::BaseField, name: String, inst_id: usize, }, @@ -141,6 +59,12 @@ pub(crate) enum MockProverError { // TODO later // r_expressions // w_expressions + LkMultiplicityError { + rom_type: ROMType, + key: u64, + count: isize, // +ve => missing in cs, -ve => missing in assignments + inst_id: usize, + }, } impl PartialEq for MockProverError { @@ -225,7 +149,7 @@ impl MockProverError { } => { let expression_fmt = fmt::expr(expression, &mut wtns, false); let wtns_fmt = fmt::wtns(&wtns, wits_in, *inst_id, wits_in_name); - let eval_fmt = fmt::field(evaluated); + let eval_fmt = fmt::base_field(evaluated, false); println!( "\nAssertZeroError {name:?}: Evaluated expression is not zero\n\ Expression: {expression_fmt}\n\ @@ -244,8 +168,8 @@ impl MockProverError { let left_expression_fmt = fmt::expr(left_expression, &mut wtns, false); let right_expression_fmt = fmt::expr(right_expression, &mut wtns, false); let wtns_fmt = fmt::wtns(&wtns, 
wits_in, *inst_id, wits_in_name); - let left_eval_fmt = fmt::field(left); - let right_eval_fmt = fmt::field(right); + let left_eval_fmt = fmt::base_field(left, false); + let right_eval_fmt = fmt::base_field(right, false); println!( "\nAssertEqualError {name:?}\n\ Left: {left_eval_fmt} != Right: {right_eval_fmt}\n\ @@ -270,6 +194,54 @@ impl MockProverError { Inst[{inst_id}]:\n{wtns_fmt}\n", ); } + Self::LkMultiplicityError { + rom_type, + key, + count, + .. + } => { + let lookups = if count.abs() > 1 { + format!("{} Lookups", count.abs()) + } else { + "Lookup".to_string() + }; + let location = if *count > 0 { + "constraint system" + } else { + "assignments" + }; + let element = match rom_type { + ROMType::U5 | ROMType::U8 | ROMType::U14 | ROMType::U16 => { + format!("Element: {key}") + } + ROMType::And => { + let (a, b) = AndTable::unpack(*key); + format!("Element: {a} < {b}") + } + ROMType::Or => { + let (a, b) = OrTable::unpack(*key); + format!("Element: {a} || {b}") + } + ROMType::Xor => { + let (a, b) = XorTable::unpack(*key); + format!("Element: {a} ^ {b}") + } + ROMType::Ltu => { + let (a, b) = LtuTable::unpack(*key); + format!("Element: {a} < {b}") + } + ROMType::Pow => { + let (a, b) = PowTable::unpack(*key); + format!("Element: {a} ** {b}") + } + ROMType::Instruction => format!("PC: {key}"), + }; + println!( + "\nLkMultiplicityError:\n\ + {lookups} of {rom_type:?} missing in {location}\n\ + {element}\n" + ); + } } } @@ -277,9 +249,14 @@ impl MockProverError { match self { Self::AssertZeroError { inst_id, .. } | Self::AssertEqualError { inst_id, .. } - | Self::LookupError { inst_id, .. } => *inst_id, + | Self::LookupError { inst_id, .. } + | Self::LkMultiplicityError { inst_id, .. 
} => *inst_id, } } + + fn contains(&self, constraint_name: &str) -> bool { + format!("{:?}", self).contains(constraint_name) + } } pub(crate) struct MockProver { @@ -317,39 +294,17 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> } } - fn load_program_table( - t_vec: &mut Vec>, - _cb: &CircuitBuilder, - challenge: [E; 2], - ) { - let mut cs = ConstraintSystem::::new(|| "mock_program"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = ProgramTableCircuit::construct_circuit(&mut cb).unwrap(); - let fixed = - ProgramTableCircuit::::generate_fixed_traces(&config, cs.num_fixed, MOCK_PROGRAM); - for table_expr in &cs.lk_table_expressions { - for row in fixed.iter_rows() { - // TODO: Find a better way to obtain the row content. - let row = row - .iter() - .map(|v| unsafe { (*v).assume_init() }.into()) - .collect::>(); - let rlc_record = eval_by_expr_with_fixed(&row, &[], &challenge, &table_expr.values); - t_vec.push(rlc_record.to_canonical_u64_vec()); - } - } - } - let mut table_vec = vec![]; load_range_table::(&mut table_vec, cb, challenge); load_range_table::(&mut table_vec, cb, challenge); + load_range_table::(&mut table_vec, cb, challenge); load_range_table::(&mut table_vec, cb, challenge); load_op_table::(&mut table_vec, cb, challenge); load_op_table::(&mut table_vec, cb, challenge); load_op_table::(&mut table_vec, cb, challenge); load_op_table::(&mut table_vec, cb, challenge); load_op_table::(&mut table_vec, cb, challenge); - load_program_table(&mut table_vec, cb, challenge); + HashSet::from_iter(table_vec) } @@ -358,7 +313,7 @@ fn load_tables(cb: &CircuitBuilder, challenge: [E; 2]) -> #[allow(clippy::type_complexity)] fn load_once_tables( cb: &CircuitBuilder, -) -> ([E; 2], &'static HashSet>) { +) -> ([E; 2], HashSet>) { static CACHE: OnceLock; 2], HashSet>)>> = OnceLock::new(); let cache = CACHE.get_or_init(StaticTypeMap::new); @@ -402,7 +357,7 @@ fn load_once_tables( let ptr = repr.as_slice().as_ptr() as *const E; *ptr }), - table, + 
table.clone(), ) } @@ -411,31 +366,47 @@ impl<'a, E: ExtensionField + Hash> MockProver { cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], challenge: [E; 2], + lkm: Option, ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, &[], Some(challenge)) + Self::run_maybe_challenge(cb, wits_in, &[], &[], Some(challenge), lkm) } pub fn run( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], + programs: &[u32], + lkm: Option, ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, wits_in, &[], None) + Self::run_maybe_challenge(cb, wits_in, programs, &[], None, lkm) } fn run_maybe_challenge( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], + input_programs: &[u32], pi: &[E::BaseField], challenge: Option<[E; 2]>, + lkm: Option, ) -> Result<(), Vec>> { - let table = challenge.map(|challenge| load_tables(cb, challenge)); - let (challenge, table) = if let Some(challenge) = challenge { - (challenge, table.as_ref().unwrap()) + // fix the program table + let mut programs = [0u32; MOCK_PROGRAM_SIZE]; + for (i, &program) in input_programs.iter().enumerate() { + programs[i] = program; + } + + // load tables + let (challenge, mut table) = if let Some(challenge) = challenge { + (challenge, load_tables(cb, challenge)) } else { load_once_tables(cb) }; - let mut errors = vec![]; + let mut prog_table = vec![]; + Self::load_program_table(&mut prog_table, &programs, challenge); + for prog in prog_table { + table.insert(prog); + } + let mut errors = vec![]; // Assert zero expressions for (expr, name) in cb .cs @@ -456,39 +427,21 @@ impl<'a, E: ExtensionField + Hash> MockProver { let right = right.neg(); let left_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &left); - let left_evaluated = left_evaluated - .get_ext_field_vec_optn() - .map(|v| v.to_vec()) - .unwrap_or_else(|| { - left_evaluated - .get_base_field_vec() - .iter() - .map(|v| E::from(*v)) - .collect_vec() - }); + let left_evaluated = 
left_evaluated.get_base_field_vec(); let right_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, &right); - let right_evaluated = right_evaluated - .get_ext_field_vec_optn() - .map(|v| v.to_vec()) - .unwrap_or_else(|| { - right_evaluated - .get_base_field_vec() - .iter() - .map(|v| E::from(*v)) - .collect_vec() - }); + let right_evaluated = right_evaluated.get_base_field_vec(); // left_evaluated.len() ?= right_evaluated.len() due to padding instance for (inst_id, (left_element, right_element)) in - left_evaluated.into_iter().zip(right_evaluated).enumerate() + izip!(left_evaluated, right_evaluated).enumerate() { if left_element != right_element { errors.push(MockProverError::AssertEqualError { left_expression: left.clone(), right_expression: right.clone(), - left: left_element, - right: right_element, + left: *left_element, + right: *right_element, name: name.clone(), inst_id, }); @@ -497,19 +450,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { } else { // contains require_zero let expr_evaluated = wit_infer_by_expr(&[], wits_in, pi, &challenge, expr); - let expr_evaluated = expr_evaluated - .get_ext_field_vec_optn() - .map(|v| v.to_vec()) - .unwrap_or_else(|| { - expr_evaluated - .get_base_field_vec() - .iter() - .map(|v| E::from(*v)) - .collect_vec() - }); + let expr_evaluated = expr_evaluated.get_base_field_vec(); for (inst_id, element) in expr_evaluated.iter().enumerate() { - if *element != E::ZERO { + if *element != E::BaseField::ZERO { errors.push(MockProverError::AssertZeroError { expression: expr.clone(), evaluated: *element, @@ -544,6 +488,100 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } + // LK Multiplicity check + if let Some(lkm_from_assignment) = lkm { + // Infer LK Multiplicity from constraint system. 
+ let lkm_from_cs = cb + .cs + .lk_expressions_items_map + .iter() + .map(|(rom_type, items)| { + ( + rom_type, + items + .iter() + .map(|expr| { + // TODO generalized to all inst_id + let inst_id = 0; + wit_infer_by_expr(&[], wits_in, pi, &challenge, expr) + .get_base_field_vec()[inst_id] + .to_canonical_u64() + }) + .collect::>(), + ) + }) + .fold(LkMultiplicity::default(), |mut lkm, (rom_type, args)| { + match rom_type { + ROMType::U5 => lkm.assert_ux::<5>(args[0]), + ROMType::U8 => lkm.assert_ux::<8>(args[0]), + ROMType::U14 => lkm.assert_ux::<14>(args[0]), + ROMType::U16 => lkm.assert_ux::<16>(args[0]), + ROMType::And => lkm.lookup_and_byte(args[0], args[1]), + ROMType::Or => lkm.lookup_or_byte(args[0], args[1]), + ROMType::Xor => lkm.lookup_xor_byte(args[0], args[1]), + ROMType::Ltu => lkm.lookup_ltu_byte(args[0], args[1]), + ROMType::Pow => { + assert_eq!(args[0], 2); + lkm.lookup_pow2(args[1]) + } + ROMType::Instruction => lkm.fetch(args[0] as u32), + }; + + lkm + }); + + let lkm_from_cs = lkm_from_cs.into_finalize_result(); + let lkm_from_assignment = lkm_from_assignment.into_finalize_result(); + + // Compare each LK Multiplicity. + + for (rom_type, cs_map, ass_map) in + izip!(ROMType::iter(), &lkm_from_cs, &lkm_from_assignment) + { + if *cs_map != *ass_map { + let cs_keys: HashSet<_> = cs_map.keys().collect(); + let ass_keys: HashSet<_> = ass_map.keys().collect(); + + // lookup missing in lkm Constraint System. + ass_keys.difference(&cs_keys).for_each(|k| { + let count_ass = ass_map.get(k).unwrap(); + errors.push(MockProverError::LkMultiplicityError { + rom_type, + key: **k, + count: *count_ass as isize, + inst_id: 0, + }) + }); + + // lookup missing in lkm Assignments. 
+ cs_keys.difference(&ass_keys).for_each(|k| { + let count_cs = cs_map.get(k).unwrap(); + errors.push(MockProverError::LkMultiplicityError { + rom_type, + key: **k, + count: -(*count_cs as isize), + inst_id: 0, + }) + }); + + // count of specific lookup differ lkm assignments and lkm cs + cs_keys.intersection(&ass_keys).for_each(|k| { + let count_cs = cs_map.get(k).unwrap(); + let count_ass = ass_map.get(k).unwrap(); + + if count_cs != count_ass { + errors.push(MockProverError::LkMultiplicityError { + rom_type, + key: **k, + count: (*count_ass as isize) - (*count_cs as isize), + inst_id: 0, + }) + } + }); + } + } + } + if errors.is_empty() { Ok(()) } else { @@ -551,56 +589,98 @@ impl<'a, E: ExtensionField + Hash> MockProver { } } - pub fn assert_satisfied( + fn load_program_table( + t_vec: &mut Vec>, + programs: &[u32; MOCK_PROGRAM_SIZE], + challenge: [E; 2], + ) { + let mut cs = ConstraintSystem::::new(|| "mock_program"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ProgramTableCircuit::<_, MOCK_PROGRAM_SIZE>::construct_circuit(&mut cb).unwrap(); + let fixed = ProgramTableCircuit::::generate_fixed_traces( + &config, + cs.num_fixed, + programs, + ); + for table_expr in &cs.lk_table_expressions { + for row in fixed.iter_rows() { + // TODO: Find a better way to obtain the row content. + let row = row + .iter() + .map(|v| unsafe { (*v).assume_init() }.into()) + .collect::>(); + let rlc_record = eval_by_expr_with_fixed(&row, &[], &challenge, &table_expr.values); + t_vec.push(rlc_record.to_canonical_u64_vec()); + } + } + } + + /// Run and check errors + /// + /// Panic, unless we see exactly the expected errors. + /// (Expecting no errors is a valid expectation.) 
+ pub fn assert_with_expected_errors( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], + programs: &[u32], + constraint_names: &[&str], challenge: Option<[E; 2]>, + lkm: Option, ) { - let result = if let Some(challenge) = challenge { - Self::run_with_challenge(cb, wits_in, challenge) + let error_groups = if let Some(challenge) = challenge { + Self::run_with_challenge(cb, wits_in, challenge, lkm) } else { - Self::run(cb, wits_in) - }; - match result { - Ok(_) => {} - Err(errors) => { - println!("======================================================"); - - println!( - r"Hints: - - If you encounter a constraint error that sporadically occurs in different environments - (e.g., passes locally but fails in CI), - this often points to unassigned witnesses during the assignment phase. - Accessing these cells before they are properly written leads to undefined behavior. + Self::run(cb, wits_in, programs, lkm) + } + .err() + .into_iter() + .flatten() + .into_group_map_by(|error| constraint_names.iter().find(|&name| error.contains(name))); + // Unexpected errors + if let Some(errors) = error_groups.get(&None) { + println!("======================================================"); + + println!( + r" +Hints: +- If you encounter a constraint error that sporadically occurs in different environments + (e.g., passes locally but fails in CI), + this often points to unassigned witnesses during the assignment phase. + Accessing these cells before they are properly written leads to undefined behavior. " - ); - - // Print errors and skip consecutive duplicates errors if they are equal. 
- let mut duplicates = 0; - let mut prev_err = None; - for error in &errors { - if prev_err.is_some() && prev_err.unwrap() == error { - duplicates += 1; - } else { - error.print(wits_in, &cb.cs.witin_namespace_map); - } - prev_err = Some(error); - } + ); - if duplicates > 0 { - println!( - "Error: {} constraints not satisfied ({} duplicates hidden)", - errors.len(), - duplicates - ); - } else { - println!("Error: {} constraints not satisfied", errors.len()); + for (count, error) in errors.iter().dedup_with_count() { + error.print(wits_in, &cb.cs.witin_namespace_map); + if count > 1 { + println!("Error: {} duplicates hidden.", count - 1); } - println!("======================================================"); - panic!("Constraints not satisfied"); } + println!("Error: {} constraints not satisfied", errors.len()); + println!("======================================================"); + panic!("(Unexpected) Constraints not satisfied"); + } + for constraint_name in constraint_names { + // Expected errors didn't happen: + error_groups.get(&Some(constraint_name)).unwrap_or_else(|| { + println!("======================================================"); + println!("Error: {} constraint satisfied", constraint_name); + println!("======================================================"); + panic!("Constraints unexpectedly satisfied"); + }); } } + + pub fn assert_satisfied( + cb: &CircuitBuilder, + wits_in: &[ArcMultilinearExtension<'a, E>], + programs: &[u32], + challenge: Option<[E; 2]>, + lkm: Option, + ) { + Self::assert_with_expected_errors(cb, wits_in, programs, &[], challenge, lkm); + } } #[cfg(test)] @@ -637,14 +717,14 @@ mod tests { let c = cb.create_witin(|| "c")?; // degree 1 - cb.require_equal(|| "a + 1 == b", b.expr(), a.expr() + 1.into())?; - cb.require_zero(|| "c - 2 == 0", c.expr() - 2.into())?; + cb.require_equal(|| "a + 1 == b", b.expr(), a.expr() + 1)?; + cb.require_zero(|| "c - 2 == 0", c.expr() - 2)?; // degree > 1 let d = cb.create_witin(|| "d")?; 
cb.require_zero( || "d*d - 6*d + 9 == 0", - d.expr() * d.expr() - d.expr() * 6.into() + 9.into(), + d.expr() * d.expr() - d.expr() * 6 + 9, )?; Ok(Self { a, b, c }) @@ -673,7 +753,7 @@ mod tests { .into(), ]; - MockProver::assert_satisfied(&builder, &wits_in, None); + MockProver::assert_satisfied(&builder, &wits_in, &[], None, None); } #[derive(Debug)] @@ -706,7 +786,7 @@ mod tests { ]; let challenge = [1.into(), 1000.into()]; - MockProver::assert_satisfied(&builder, &wits_in, Some(challenge)); + MockProver::assert_satisfied(&builder, &wits_in, &[], Some(challenge), None); } #[test] @@ -720,7 +800,7 @@ mod tests { let wits_in = vec![vec![Goldilocks::from(123)].into_mle().into()]; let challenge = [2.into(), 1000.into()]; - let result = MockProver::run_with_challenge(&builder, &wits_in, challenge); + let result = MockProver::run_with_challenge(&builder, &wits_in, challenge, None); assert!(result.is_err(), "Expected error"); let err = result.unwrap_err(); assert_eq!(err, vec![MockProverError::LookupError { @@ -833,7 +913,9 @@ mod tests { .into_iter() .map(|v| v.into()) .collect_vec(), + &[], Some([1.into(), 1000.into()]), + None, ); } @@ -869,7 +951,9 @@ mod tests { .into_iter() .map(|v| v.into()) .collect_vec(), + &[], Some([1.into(), 1000.into()]), + None, ); } @@ -953,7 +1037,9 @@ mod tests { .into_iter() .map(|v| v.into()) .collect_vec(), + &[], Some([1.into(), 1000.into()]), + None, ); } @@ -990,7 +1076,9 @@ mod tests { .into_iter() .map(|v| v.into()) .collect_vec(), + &[], Some([1.into(), 1000.into()]), + None, ); } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index dc07bda6a..c87283c01 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -4,6 +4,7 @@ use std::{ sync::Arc, }; +use ff::Field; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ @@ -31,7 +32,7 @@ use crate::{ structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, 
ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, proper_num_threads}, + utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, }; @@ -51,7 +52,6 @@ impl> ZKVMProver { &self, witnesses: ZKVMWitnesses, pi: PublicValues, - max_threads: usize, mut transcript: Transcript, ) -> Result, ZKVMError> { let mut vm_proof = ZKVMProof::empty(pi); @@ -110,7 +110,9 @@ impl> ZKVMProver { let wits_commit = commitments.remove(circuit_name).unwrap(); // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); - let is_opcode_circuit = cs.lk_table_expressions.is_empty(); + let is_opcode_circuit = cs.lk_table_expressions.is_empty() + && cs.r_table_expressions.is_empty() + && cs.w_table_expressions.is_empty(); if is_opcode_circuit { tracing::debug!( @@ -132,7 +134,6 @@ impl> ZKVMProver { wits_commit, pi, num_instances, - max_threads, transcript, &challenges, )?; @@ -152,8 +153,6 @@ impl> ZKVMProver { witness.into_iter().map(|v| v.into()).collect_vec(), wits_commit, pi, - num_instances, - max_threads, transcript, &challenges, )?; @@ -184,7 +183,6 @@ impl> ZKVMProver { wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], num_instances: usize, - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -318,7 +316,6 @@ impl> ZKVMProver { let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; assert!(record_r_out_evals.len() == NUM_FANIN && record_w_out_evals.len() == NUM_FANIN); let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, vec![ TowerProverSpec { witness: r_wit_layers, @@ -361,7 +358,7 @@ impl> ZKVMProver { rt_tower[..log2_num_instances].to_vec(), ); - let num_threads = proper_num_threads(log2_num_instances, max_threads); + let num_threads = optimal_sumcheck_threads(log2_num_instances); let alpha_pow = get_challenge_pows( MAINCONSTRAIN_SUMCHECK_BATCH_SIZE + 
cs.assert_zero_sumcheck_expressions.len(), transcript, @@ -490,10 +487,10 @@ impl> ZKVMProver { let expected_zero_poly = wit_infer_by_expr(&[], &witnesses, pi, challenges, expr); let top_100_errors = expected_zero_poly - .get_ext_field_vec() + .get_base_field_vec() .iter() .enumerate() - .filter(|(_, v)| **v != E::ZERO) + .filter(|(_, v)| **v != E::BaseField::ZERO) .take(100) .collect_vec(); if !top_100_errors.is_empty() { @@ -611,6 +608,9 @@ impl> ZKVMProver { } #[allow(clippy::too_many_arguments)] + /// support batch prove for logup + product arguments each with different num_vars() + /// side effect: concurrency will be determine based on min(thread, num_vars()), + /// so suggest dont batch too small table (size < threads) with large table together pub fn create_table_proof( &self, name: &str, @@ -619,8 +619,6 @@ impl> ZKVMProver { witnesses: Vec>, wits_commit: PCS::CommitmentWithData, pi: &[E::BaseField], - num_instances: usize, - max_threads: usize, transcript: &mut Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { @@ -632,163 +630,345 @@ impl> ZKVMProver { .iter() .map(|f| -> ArcMultilinearExtension { Arc::new(f.get_ranged_mle(1, 0)) }) .collect::>>(); - let log2_num_instances = ceil_log2(num_instances); - let next_pow2_instances = 1 << log2_num_instances; // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); assert_eq!(fixed.len(), cs.num_fixed); - assert!(witnesses.iter().all(|v| { - v.num_vars() == log2_num_instances && v.evaluations().len() == next_pow2_instances - })); - assert!(!cs.lk_table_expressions.is_empty()); + // check all witness size are power of 2 + assert!( + witnesses + .iter() + .all(|v| { v.evaluations().len().is_power_of_two() }) + ); + assert!( + !cs.r_table_expressions.is_empty() + || !cs.w_table_expressions.is_empty() + || !cs.lk_table_expressions.is_empty() + ); + assert!( + cs.r_table_expressions + .iter() + .zip_eq(cs.w_table_expressions.iter()) + .all(|(r, w)| r.table_len == w.table_len) + ); // main 
constraint: lookup denominator and numerator record witness inference let span = entered_span!("wit_inference::record"); - let records_wit: Vec> = cs - .lk_table_expressions + let mut records_wit: Vec> = cs + .r_table_expressions .par_iter() - .map(|lk| &lk.values) + .map(|r| &r.values) + .chain(cs.w_table_expressions.par_iter().map(|w| &w.values)) .chain( cs.lk_table_expressions .par_iter() .map(|lk| &lk.multiplicity), ) + .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr) }) .collect(); - let (lk_d_wit, lk_n_wit) = records_wit.split_at(cs.lk_table_expressions.len()); - exit_span!(span); + let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); + let min_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).min().unwrap(); + let (r_set_wit, remains) = records_wit.split_at_mut(cs.r_table_expressions.len()); + let (w_set_wit, remains) = remains.split_at_mut(cs.w_table_expressions.len()); + let (lk_n_wit, remains) = remains.split_at_mut(cs.lk_table_expressions.len()); + let (lk_d_wit, _empty) = remains.split_at_mut(cs.lk_table_expressions.len()); + assert!(_empty.is_empty()); - let lk_counts_per_instance = cs.lk_table_expressions.len(); - let log2_lk_count = ceil_log2(lk_counts_per_instance); + exit_span!(span); // infer all tower witness after last layer let span = entered_span!("wit_inference::tower_witness_lk_last_layer"); - // TODO optimize last layer to avoid alloc new vector to save memory - let lk_denominator_last_layer = - interleaving_mles_to_mles(lk_d_wit, num_instances, NUM_FANIN_LOGUP, E::ONE); - let lk_numerator_last_layer = - interleaving_mles_to_mles(lk_n_wit, num_instances, NUM_FANIN_LOGUP, E::ZERO); - assert_eq!(lk_denominator_last_layer.len(), NUM_FANIN_LOGUP); - assert_eq!(lk_numerator_last_layer.len(), NUM_FANIN_LOGUP); + let mut r_set_last_layer = r_set_wit + .iter() + 
.chain(w_set_wit.iter()) + .map(|wit| { + let (first, second) = wit + .get_ext_field_vec() + .split_at(wit.evaluations().len() / 2); + let res = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + assert_eq!(res.len(), NUM_FANIN_LOGUP); + res + }) + .collect::>(); + let w_set_last_layer = r_set_last_layer.split_off(r_set_wit.len()); + + let lk_numerator_last_layer = lk_n_wit + .iter() + .map(|wit| { + let (first, second) = wit + .get_base_field_vec() + .split_at(wit.evaluations().len() / 2); + let res = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + assert_eq!(res.len(), NUM_FANIN_LOGUP); + res + }) + .collect::>(); + let lk_denominator_last_layer = lk_d_wit + .iter_mut() + .map(|wit| { + let (first, second) = wit + .get_ext_field_vec() + .split_at(wit.evaluations().len() / 2); + let res = vec![ + first.to_vec().into_mle().into(), + second.to_vec().into_mle().into(), + ]; + assert_eq!(res.len(), NUM_FANIN_LOGUP); + res + }) + .collect::>(); exit_span!(span); let span = entered_span!("wit_inference::tower_witness_lk_layers"); - let lk_wit_layers = - infer_tower_logup_witness(Some(lk_numerator_last_layer), lk_denominator_last_layer); + let r_wit_layers = r_set_last_layer + .into_iter() + .zip(r_set_wit.iter()) + .map(|(last_layer, origin_mle)| { + infer_tower_product_witness(origin_mle.num_vars(), last_layer, NUM_FANIN) + }) + .collect_vec(); + let w_wit_layers = w_set_last_layer + .into_iter() + .zip(w_set_wit.iter()) + .map(|(last_layer, origin_mle)| { + infer_tower_product_witness(origin_mle.num_vars(), last_layer, NUM_FANIN) + }) + .collect_vec(); + let lk_wit_layers = lk_numerator_last_layer + .into_iter() + .zip(lk_denominator_last_layer) + .map(|(lk_n, lk_d)| infer_tower_logup_witness(Some(lk_n), lk_d)) + .collect_vec(); exit_span!(span); if cfg!(test) { // sanity check - assert_eq!(lk_wit_layers.len(), log2_num_instances + log2_lk_count); - 
assert!(lk_wit_layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); - p1.evaluations().len() == expected_size - && p2.evaluations().len() == expected_size - && q1.evaluations().len() == expected_size - && q2.evaluations().len() == expected_size + assert_eq!(r_wit_layers.len(), cs.r_table_expressions.len()); + assert!( + r_wit_layers + .iter() + .zip(r_set_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(r_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size + }) + })); + + assert_eq!(w_wit_layers.len(), cs.w_table_expressions.len()); + assert!( + w_wit_layers + .iter() + .zip(w_set_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(w_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size + }) + })); + + assert_eq!(lk_wit_layers.len(), cs.lk_table_expressions.len()); + assert!( + lk_wit_layers + .iter() + .zip(lk_n_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(lk_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); + p1.evaluations().len() == expected_size + && p2.evaluations().len() == expected_size + && q1.evaluations().len() == expected_size + && q2.evaluations().len() == expected_size + }) })); } // product constraint tower sumcheck let span = entered_span!("sumcheck::tower"); // final evals for verifier - let lk_p1_out_eval = lk_wit_layers[0][0].get_ext_field_vec()[0]; - let 
lk_p2_out_eval = lk_wit_layers[0][1].get_ext_field_vec()[0]; - let lk_q1_out_eval = lk_wit_layers[0][2].get_ext_field_vec()[0]; - let lk_q2_out_eval = lk_wit_layers[0][3].get_ext_field_vec()[0]; + let r_out_evals = r_wit_layers + .iter() + .map(|r_wit_layers| { + [ + r_wit_layers[0][0].get_ext_field_vec()[0], + r_wit_layers[0][1].get_ext_field_vec()[0], + ] + }) + .collect_vec(); + let w_out_evals = w_wit_layers + .iter() + .map(|w_wit_layers| { + [ + w_wit_layers[0][0].get_ext_field_vec()[0], + w_wit_layers[0][1].get_ext_field_vec()[0], + ] + }) + .collect_vec(); + let lk_out_evals = lk_wit_layers + .iter() + .map(|lk_wit_layers| { + [ + // p1, p2, q1, q2 + lk_wit_layers[0][0].get_ext_field_vec()[0], + lk_wit_layers[0][1].get_ext_field_vec()[0], + lk_wit_layers[0][2].get_ext_field_vec()[0], + lk_wit_layers[0][3].get_ext_field_vec()[0], + ] + }) + .collect_vec(); + let (rt_tower, tower_proof) = TowerProver::create_proof( - max_threads, - vec![], - vec![TowerProverSpec { - witness: lk_wit_layers, - }], + // pattern [r1, w1, r2, w2, ...] 
same pair are chain together + r_wit_layers + .into_iter() + .zip(w_wit_layers) + .flat_map(|(r, w)| { + vec![TowerProverSpec { witness: r }, TowerProverSpec { + witness: w, + }] + }) + .collect_vec(), + lk_wit_layers + .into_iter() + .map(|lk_wit_layers| TowerProverSpec { + witness: lk_wit_layers, + }) + .collect_vec(), NUM_FANIN_LOGUP, transcript, ); - assert_eq!(rt_tower.len(), log2_num_instances + log2_lk_count); + assert_eq!( + rt_tower.len(), // num var length should equal to max_num_instance + max_log2_num_instance + ); exit_span!(span); - // selector layer sumcheck - let span = entered_span!("sumcheck::main_sel"); - let rt_lk: Vec = tower_proof.logup_specs_points[0] - .last() - .expect("error getting rt_lk") - .to_vec(); - - let num_threads = proper_num_threads(log2_num_instances, max_threads); - // 2 for denominator and numerator - let alpha_pow = get_challenge_pows(2, transcript); - let mut alpha_pow_iter = alpha_pow.iter(); - let (alpha_lk_d, alpha_lk_n) = ( - alpha_pow_iter.next().unwrap(), - alpha_pow_iter.next().unwrap(), - ); - // create selector: all ONE, but padding ZERO to ceil_log2 - let sel_lk: ArcMultilinearExtension = { - let mut sel_lk = build_eq_x_r_vec(&rt_lk[log2_lk_count..]); - if num_instances < sel_lk.len() { - sel_lk.splice( - num_instances..sel_lk.len(), - std::iter::repeat(E::ZERO).take(sel_lk.len() - num_instances), + // same point sumcheck is optional when all witin + fixed are in same num_vars + let is_skip_same_point_sumcheck = witnesses + .iter() + .chain(fixed.iter()) + .map(|v| v.num_vars()) + .all_equal(); + + let (input_open_point, same_r_sumcheck_proofs, rw_in_evals, lk_in_evals) = + if is_skip_same_point_sumcheck { + (rt_tower, None, vec![], vec![]) + } else { + // one sumcheck to make them opening on same point r (with different prefix) + // If all table length are the same, we can skip this sumcheck + let span = entered_span!("sumcheck::opening_same_point"); + // NOTE: max concurrency will be dominated by smallest table 
since it will blo + let num_threads = optimal_sumcheck_threads(min_log2_num_instance); + let alpha_pow = get_challenge_pows( + cs.r_table_expressions.len() + + cs.w_table_expressions.len() + + cs.lk_table_expressions.len() * 2, + transcript, ); - } - sel_lk.into_mle().into() - }; + let mut alpha_pow_iter = alpha_pow.iter(); - let mut virtual_polys = VirtualPolynomials::::new(num_threads, log2_num_instances); + // create eq + // TODO same size rt lead to same identical poly eq which can be merged together + let eq = tower_proof + .prod_specs_points + .iter() + .step_by(2) // r,w are in same length therefore share same point + .chain(tower_proof.logup_specs_points.iter()) + .map(|layer_points| { + let rt = layer_points.last().unwrap(); + build_eq_x_r_vec(rt).into_mle().into() + }) + .collect::>>(); - let eq_lk = build_eq_x_r_vec(&rt_lk[..log2_lk_count]); - // lk denominator - // rt := rt || rs - for i in 0..lk_counts_per_instance { - // \sum_t (sel(rt, t) * (\sum_i alpha_lk_d * eq(rs, i) * lk_d_record[i])) - virtual_polys.add_mle_list(vec![&sel_lk, &lk_d_wit[i]], eq_lk[i] * alpha_lk_d); - } - // \sum_t alpha_lk_d * sel(rt, t) * (\sum_i (eq(rs, i)) - 1) - virtual_polys.add_mle_list( - vec![&sel_lk], - *alpha_lk_d * (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE), - ); + let (eq_rw, eq_lk) = eq.split_at(cs.r_table_expressions.len()); - // lk numerator - for i in 0..lk_counts_per_instance { - // \sum_t (sel(rt, t) * (\sum_i alpha_lk_n * eq(rs, i) * lk_n_record[i])) - virtual_polys.add_mle_list(vec![&sel_lk, &lk_n_wit[i]], eq_lk[i] * alpha_lk_n); - } + let mut virtual_polys = + VirtualPolynomials::::new(num_threads, max_log2_num_instance); - let (sel_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( - num_threads, - virtual_polys.get_batched_polys(), - transcript, - ); - let sel_evals = state.get_mle_final_evaluations(); - assert_eq!( - sel_evals.len(), - lk_counts_per_instance * 2 + 1 // 1 for sel_lk - ); - let mut sel_evals_iter = 
sel_evals.into_iter(); - sel_evals_iter.next(); // skip sel_lk - let lk_d_in_evals = (0..lk_counts_per_instance) - .map(|_| sel_evals_iter.next().unwrap()) - .collect_vec(); - let lk_n_in_evals = (0..lk_counts_per_instance) - .map(|_| sel_evals_iter.next().unwrap()) - .collect_vec(); - assert!(sel_evals_iter.count() == 0); - let input_open_point = sel_sumcheck_proofs.point.clone(); - assert!(input_open_point.len() == log2_num_instances); - exit_span!(span); + // alpha_r{i} * eq(rt_{i}, s) * r(s) + alpha_w{i} * eq(rt_{i}, s) * w(s) + for ((r_set_wit, w_set_wit), eq) in r_set_wit + .iter() + .zip_eq(w_set_wit.iter()) + .zip_eq(eq_rw.iter()) + { + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, r_set_wit], *alpha); + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, w_set_wit], *alpha); + } + + // alpha_lkn{i} * eq(rt_{i}, s) * lk_n(s) + alpha_lkd{i} * eq(rt_{i}, s) * lk_d(s) + for ((lk_n_wit, lk_d_wit), eq) in + lk_n_wit.iter().zip_eq(lk_d_wit.iter()).zip_eq(eq_lk.iter()) + { + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, lk_n_wit], *alpha); + let alpha = alpha_pow_iter.next().unwrap(); + virtual_polys.add_mle_list(vec![eq, lk_d_wit], *alpha); + } + + let (same_r_sumcheck_proofs, state) = IOPProverStateV2::prove_batch_polys( + num_threads, + virtual_polys.get_batched_polys(), + transcript, + ); + let evals = state.get_mle_final_evaluations(); + let mut evals_iter = evals.into_iter(); + let rw_in_evals = cs + // r, w table len are identical + .r_table_expressions + .iter() + .flat_map(|_table| { + let _eq = evals_iter.next().unwrap(); // skip eq + [evals_iter.next().unwrap(), evals_iter.next().unwrap()] // r, w + }) + .collect_vec(); + let lk_in_evals = cs + .lk_table_expressions + .iter() + .flat_map(|_table| { + let _eq = evals_iter.next().unwrap(); // skip eq + [evals_iter.next().unwrap(), evals_iter.next().unwrap()] // n, d + }) + .collect_vec(); + 
assert_eq!(evals_iter.count(), 0); + + let input_open_point = same_r_sumcheck_proofs.point.clone(); + assert_eq!(input_open_point.len(), max_log2_num_instance); + exit_span!(span); + + ( + input_open_point, + Some(same_r_sumcheck_proofs.proofs), + rw_in_evals, + lk_in_evals, + ) + }; let span = entered_span!("fixed::evals + witin::evals"); let mut evals = witnesses .par_iter() .chain(fixed.par_iter()) - .map(|poly| poly.evaluate(&input_open_point)) + .map(|poly| poly.evaluate(&input_open_point[..poly.num_vars()])) .collect::>(); let fixed_in_evals = evals.split_off(witnesses.len()); let wits_in_evals = evals; @@ -834,15 +1014,13 @@ impl> ZKVMProver { ); Ok(ZKVMTableProof { - num_instances, - lk_p1_out_eval, - lk_p2_out_eval, - lk_q1_out_eval, - lk_q2_out_eval, + r_out_evals, + w_out_evals, + lk_out_evals, + same_r_sumcheck_proofs, + rw_in_evals, + lk_in_evals, tower_proof, - sel_sumcheck_proofs: sel_sumcheck_proofs.proofs, - lk_d_in_evals, - lk_n_in_evals, fixed_in_evals, fixed_opening_proof, wits_in_evals, @@ -889,7 +1067,6 @@ impl TowerProofs { /// Tower Prover impl TowerProver { pub fn create_proof<'a, E: ExtensionField>( - max_threads: usize, prod_specs: Vec>, logup_specs: Vec>, num_fanin: usize, @@ -924,7 +1101,7 @@ impl TowerProver { let (next_rt, _) = (1..=max_round_index).fold((initial_rt, alpha_pows), |(out_rt, alpha_pows), round| { // in first few round we just run on single thread - let num_threads = proper_num_threads(out_rt.len(), max_threads); + let num_threads = optimal_sumcheck_threads(out_rt.len()); let eq: ArcMultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle().into(); let mut virtual_polys = VirtualPolynomials::::new(num_threads, out_rt.len()); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index a87f58585..f928a8918 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -15,6 +15,7 @@ use transcript::Transcript; use crate::{ circuit_builder::CircuitBuilder, + declare_program, 
error::ZKVMError, expression::{Expression, ToExpr, WitIn}, instructions::{ @@ -140,7 +141,6 @@ fn test_rw_lk_expression_combination() { commit, &[], num_instances, - 1, &mut transcript, &prover_challenges, ) @@ -178,16 +178,24 @@ fn test_rw_lk_expression_combination() { test_rw_lk_expression_combination_inner::<17, 61>(); } +const PROGRAM_SIZE: usize = 4; #[allow(clippy::unusual_byte_groupings)] const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011; #[allow(clippy::unusual_byte_groupings)] -const PROGRAM_CODE: [u32; 4] = [ - // func7 rs2 rs1 f3 rd opcode - 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 - ECALL_HALT, // ecall halt - ECALL_HALT, // ecall halt - ECALL_HALT, // ecall halt -]; +const PROGRAM_CODE: [u32; PROGRAM_SIZE] = { + let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE]; + + declare_program!( + program, + // func7 rs2 rs1 f3 rd opcode + 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 + ECALL_HALT, // ecall halt + ECALL_HALT, // ecall halt + ECALL_HALT, // ecall halt + ); + program +}; + #[ignore = "this case is already tested in riscv_example as ecall_halt has only one instance"] #[test] fn test_single_add_instance_e2e() { @@ -202,9 +210,8 @@ fn test_single_add_instance_e2e() { let halt_config = zkvm_cs.register_opcode_circuit::>(); let u16_range_config = zkvm_cs.register_table_circuit::>(); - let prog_config = zkvm_cs.register_table_circuit::>(); + let prog_config = zkvm_cs.register_table_circuit::>(); - let program_code: Vec = PROGRAM_CODE.to_vec(); let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); zkvm_fixed_traces.register_opcode_circuit::>(&zkvm_cs); @@ -215,10 +222,10 @@ fn test_single_add_instance_e2e() { &(), ); - zkvm_fixed_traces.register_table_circuit::>( + zkvm_fixed_traces.register_table_circuit::>( &zkvm_cs, prog_config.clone(), - &program_code, + &PROGRAM_CODE, ); let pk = zkvm_cs @@ 
-272,13 +279,17 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &u16_range_config, &()) .unwrap(); zkvm_witness - .assign_table_circuit::>(&zkvm_cs, &prog_config, &program_code.len()) + .assign_table_circuit::>( + &zkvm_cs, + &prog_config, + &PROGRAM_CODE.len(), + ) .unwrap(); - let pi = PublicValues::new(0, 0); + let pi = PublicValues::new(0, 0, 0, 0, 0); let transcript = Transcript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, pi, 1, transcript) + .create_proof(zkvm_witness, pi, transcript) .expect("create_proof failed"); let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index f692f445b..e09e70574 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, - op_mle, + op_mle_xa_b, op_mle3_range, util::ceil_log2, virtual_poly_v2::ArcMultilinearExtension, }; @@ -89,6 +89,23 @@ pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( .collect::>>() } +macro_rules! 
tower_mle_4 { + ($p1:ident, $p2:ident, $q1:ident, $q2:ident, $acc_p:ident, $acc_q:ident, $start_index:ident, $cur_len:ident) => { + $q1[$start_index..][..$cur_len] + .par_iter() + .zip($q2[$start_index..][..$cur_len].par_iter()) + .zip($p1[$start_index..][..$cur_len].par_iter()) + .zip($p2[$start_index..][..$cur_len].par_iter()) + .zip($acc_p.par_iter_mut()) + .zip($acc_q.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(((((q1, q2), p1), p2), p_eval), q_eval)| { + *p_eval = *q1 * p2 + *q2 * p1; + *q_eval = *q1 * q2; + }) + }; +} + /// infer logup witness from last layer /// return is the ([p1,p2], [q1,q2]) for each layer pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( @@ -128,18 +145,13 @@ pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( FieldType::Ext(p2), FieldType::Ext(q1), FieldType::Ext(q2), - ) => q1[start_index..][..cur_len] - .par_iter() - .zip(q2[start_index..][..cur_len].par_iter()) - .zip(p1[start_index..][..cur_len].par_iter()) - .zip(p2[start_index..][..cur_len].par_iter()) - .zip(p_evals.par_iter_mut()) - .zip(q_evals.par_iter_mut()) - .with_min_len(MIN_PAR_SIZE) - .for_each(|(((((q1, q2), p1), p2), p_eval), q_eval)| { - *p_eval = *p2 * q1 + *p1 * q2; - *q_eval = *q1 * q2; - }), + ) => tower_mle_4!(p1, p2, q1, q2, p_evals, q_evals, start_index, cur_len), + ( + FieldType::Base(p1), + FieldType::Base(p2), + FieldType::Ext(q1), + FieldType::Ext(q2), + ) => tower_mle_4!(p1, p2, q1, q2, p_evals, q_evals, start_index, cur_len), _ => unreachable!(), }; } else { @@ -324,21 +336,10 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( }) }, &|x, a, b| { - let a = op_mle!( - |a| { - assert_eq!(a.len(), 1); - a[0] - }, - |a| a.into() - ); - let b = op_mle!( - |b| { - assert_eq!(b.len(), 1); - b[0] - }, - |b| b.into() - ); - op_mle!(|x| { + op_mle_xa_b!(|x, a, b| { + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + let (a, b) = (a[0], b[0]); 
Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( ceil_log2(x.len()), x.par_iter() @@ -417,10 +418,16 @@ mod tests { virtual_poly_v2::ArcMultilinearExtension, }; - use crate::scheme::utils::{ - infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + expression::{Expression, ToExpr}, + scheme::utils::{ + infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, + }, }; + use super::wit_infer_by_expr; + #[test] fn test_infer_tower_witness() { type E = GoldilocksExt2; @@ -659,4 +666,60 @@ mod tests { ]) ); } + + #[test] + fn test_wit_infer_by_expr_base_field() { + type E = goldilocks::GoldilocksExt2; + type B = goldilocks::Goldilocks; + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let a = cb.create_witin(|| "a").unwrap(); + let b = cb.create_witin(|| "b").unwrap(); + let c = cb.create_witin(|| "c").unwrap(); + + let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); + + let res = wit_infer_by_expr( + &[], + &[ + vec![B::from(1)].into_mle().into(), + vec![B::from(2)].into_mle().into(), + vec![B::from(3)].into_mle().into(), + ], + &[], + &[], + &expr, + ); + res.get_base_field_vec(); + } + + #[test] + fn test_wit_infer_by_expr_ext_field() { + type E = goldilocks::GoldilocksExt2; + type B = goldilocks::Goldilocks; + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let a = cb.create_witin(|| "a").unwrap(); + let b = cb.create_witin(|| "b").unwrap(); + let c = cb.create_witin(|| "c").unwrap(); + + let expr: Expression = a.expr() + + b.expr() + + a.expr() * b.expr() + + (c.expr() * 3 + 2) + + Expression::Challenge(0, 1, E::ONE, E::ONE); + + let res = wit_infer_by_expr( + &[], + &[ + vec![B::from(1)].into_mle().into(), + vec![B::from(2)].into_mle().into(), + vec![B::from(3)].into_mle().into(), + ], + &[], + 
&[E::ONE], + &expr, + ); + res.get_ext_field_vec(); + } } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index cde70b952..e1a406887 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::{Itertools, interleave, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, @@ -155,18 +155,19 @@ impl> ZKVMVerifier )?; tracing::info!("verified proof for table {}", name); - logup_sum -= table_proof.lk_p1_out_eval * table_proof.lk_q1_out_eval.invert().unwrap(); - logup_sum -= table_proof.lk_p2_out_eval * table_proof.lk_q2_out_eval.invert().unwrap(); + logup_sum = table_proof + .lk_out_evals + .iter() + .fold(logup_sum, |acc, [p1, p2, q1, q2]| { + acc - *p1 * q1.invert().unwrap() - *p2 * q2.invert().unwrap() + }); + + prod_w *= table_proof.w_out_evals.iter().flatten().product::(); + prod_r *= table_proof.r_out_evals.iter().flatten().product::(); } logup_sum -= E::from(dummy_table_item_multiplicity as u64) * dummy_table_item.invert().unwrap(); - // check rw_set equality across all proofs - // TODO: enable this when we have cpu init/finalize and mem init/finalize - // if prod_r != prod_w { - // return Err(ZKVMError::VerifyError("prod_r != prod_w".into())); - // } - // check logup relation across all proofs if logup_sum != E::ZERO { return Err(ZKVMError::VerifyError(format!( @@ -175,6 +176,27 @@ impl> ZKVMVerifier ))); } + let initial_global_state = eval_by_expr_with_instance( + &[], + &[], + pi, + &challenges, + &self.vk.initial_global_state_expr, + ); + prod_w *= initial_global_state; + let finalize_global_state = eval_by_expr_with_instance( + &[], + &[], + pi, + &challenges, + &self.vk.finalize_global_state_expr, + ); + prod_r *= finalize_global_state; + // check rw_set equality across all proofs + if 
prod_r != prod_w { + return Err(ZKVMError::VerifyError("prod_r != prod_w".into())); + } + Ok(true) } @@ -406,8 +428,7 @@ impl> ZKVMVerifier if cs.assert_zero_expressions.iter().any(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO }) { - // TODO add me back - // return Err(ZKVMError::VerifyError("zero expression != 0")); + return Err(ZKVMError::VerifyError("zero expression != 0".into())); } tracing::debug!( @@ -443,113 +464,196 @@ impl> ZKVMVerifier challenges: &[E; 2], ) -> Result, ZKVMError> { let cs = circuit_vk.get_cs(); - let lk_counts_per_instance = cs.lk_table_expressions.len(); - let log2_lk_count = ceil_log2(lk_counts_per_instance); - - let num_instances = proof.num_instances; - let log2_num_instances = ceil_log2(num_instances); + debug_assert!( + cs.r_table_expressions + .iter() + .zip_eq(cs.w_table_expressions.iter()) + .all(|(r, w)| r.table_len == w.table_len) + ); + let is_skip_same_point_sumcheck = cs + .r_table_expressions + .iter() + .chain(cs.w_table_expressions.iter()) + .map(|rw| rw.table_len) + .chain(cs.lk_table_expressions.iter().map(|lk| lk.table_len)) + .all_equal(); // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; - let expected_max_round = log2_num_instances + log2_lk_count; - let (_, _, logup_p_evals, logup_q_evals) = TowerVerify::verify( - vec![], - vec![vec![ - proof.lk_p1_out_eval, - proof.lk_p2_out_eval, - proof.lk_q1_out_eval, - proof.lk_q2_out_eval, - ]], - tower_proofs, - vec![expected_max_round], - num_logup_fanin, - transcript, - )?; - assert!(logup_q_evals.len() == 1, "[lk_q_record]"); - assert!(logup_p_evals.len() == 1, "[lk_p_record]"); - assert_eq!(logup_p_evals[0].point, logup_q_evals[0].point); - - // verify selector layer sumcheck - let rt_lk: Vec = logup_p_evals[0].point.to_vec(); - - // 2 for denominator and numerator - let alpha_pow = get_challenge_pows(2, transcript); - let mut alpha_pow_iter = alpha_pow.iter(); - let (alpha_lk_d, 
alpha_lk_n) = ( - alpha_pow_iter.next().unwrap(), - alpha_pow_iter.next().unwrap(), + // TODO probably move expected_max_rounds to verifier key + let expected_rounds = cs + // w_table_expression round match with r_table_expression so check any of them sufficient + .r_table_expressions + .iter() + .flat_map(|r| { + let num_vars = ceil_log2(r.table_len); + [num_vars, num_vars] + }) + .chain( + cs.lk_table_expressions + .iter() + .map(|l| ceil_log2(l.table_len)), + ) + .collect_vec(); + let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); + let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = + TowerVerify::verify( + proof + .r_out_evals + .iter() + .zip(proof.w_out_evals.iter()) + .flat_map(|(r_evals, w_evals)| [r_evals.to_vec(), w_evals.to_vec()]) + .collect_vec(), + proof + .lk_out_evals + .iter() + .map(|eval| eval.to_vec()) + .collect_vec(), + tower_proofs, + expected_rounds, + num_logup_fanin, + transcript, + )?; + assert_eq!( + logup_q_point_and_eval.len(), + cs.lk_table_expressions.len(), + "[lk_q_record] mismatch length" ); - // alpha_lk * (out_lk_q - one) + alpha_lk_n * out_lk_p - let claim_sum = - *alpha_lk_d * (logup_q_evals[0].eval - E::ONE) + *alpha_lk_n * logup_p_evals[0].eval; - let sel_subclaim = IOPVerifierState::verify( - claim_sum, - &IOPProof { - point: vec![], // final claimed point will be derived from sumcheck protocol - proofs: proof.sel_sumcheck_proofs.clone(), - }, - &VPAuxInfo { - max_degree: SEL_DEGREE.max(cs.max_non_lc_degree), - num_variables: log2_num_instances, - phantom: PhantomData, - }, - transcript, + assert_eq!( + logup_p_point_and_eval.len(), + cs.lk_table_expressions.len(), + "[lk_p_record] mismatch length" ); - let (input_opening_point, expected_evaluation) = ( - sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), - sel_subclaim.expected_evaluation, + assert_eq!( + prod_point_and_eval.len(), + cs.r_table_expressions.len() + cs.w_table_expressions.len(), + 
"[prod_record] mismatch length" ); - let eq_lk = build_eq_x_r_vec_sequential(&rt_lk[..log2_lk_count]); - let sel_lk = eq_eval_less_or_equal_than( - num_instances - 1, - &rt_lk[log2_lk_count..], - &input_opening_point, - ); + let (input_opening_point, in_evals) = if is_skip_same_point_sumcheck { + ( + rt_tower, + prod_point_and_eval + .into_iter() + .chain( + logup_p_point_and_eval + .into_iter() + .zip_eq(logup_q_point_and_eval) + .flat_map(|(p_point_and_eval, q_point_and_eval)| { + [p_point_and_eval, q_point_and_eval] + }), + ) + .map(|point_and_eval| point_and_eval.eval) + .collect_vec(), + ) + } else { + assert!(proof.same_r_sumcheck_proofs.is_some()); + + // verify opening same point layer sumcheck + let alpha_pow = get_challenge_pows( + cs.r_table_expressions.len() + + cs.w_table_expressions.len() + + cs.lk_table_expressions.len() * 2, // 2 for lk numerator and denominator + transcript, + ); - let computed_evals = [ - // lookup denominator - *alpha_lk_d - * sel_lk - * ((0..lk_counts_per_instance) - .map(|i| proof.lk_d_in_evals[i] * eq_lk[i]) - .sum::() - + (eq_lk[lk_counts_per_instance..].iter().sum::() - E::ONE)), - *alpha_lk_n - * sel_lk - * ((0..lk_counts_per_instance) - .map(|i| proof.lk_n_in_evals[i] * eq_lk[i]) - .sum::()), - ] - .iter() - .sum::(); - if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError( - "sel evaluation verify failed".into(), - )); - } - // verify records (degree = 1) statement, thus no sumcheck - if cs - .lk_table_expressions + // \sum_i alpha_{i} * (out_r_eval{i}) + // + \sum_i alpha_{i} * (out_w_eval{i}) + // + \sum_i alpha_{i} * (out_lk_n{i}) + // + \sum_i alpha_{i} * (out_lk_d{i}) + let claim_sum = prod_point_and_eval + .iter() + .zip(alpha_pow.iter()) + .map(|(point_and_eval, alpha)| *alpha * point_and_eval.eval) + .sum::() + + interleave(&logup_p_point_and_eval, &logup_q_point_and_eval) + .zip_eq(alpha_pow.iter().skip(prod_point_and_eval.len())) + .map(|(point_n_eval, alpha)| *alpha * point_n_eval.eval) 
+ .sum::(); + let sel_subclaim = IOPVerifierState::verify( + claim_sum, + &IOPProof { + point: vec![], // final claimed point will be derived from sumcheck protocol + proofs: proof.same_r_sumcheck_proofs.clone().unwrap(), + }, + &VPAuxInfo { + max_degree: SEL_DEGREE, + num_variables: expected_max_rounds, + phantom: PhantomData, + }, + transcript, + ); + let (input_opening_point, expected_evaluation) = ( + sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), + sel_subclaim.expected_evaluation, + ); + + let computed_evals = [ + // r, w + prod_point_and_eval + .into_iter() + .zip_eq(proof.rw_in_evals.iter()) + .zip(alpha_pow.iter()) + .map(|((point_and_eval, in_eval), alpha)| { + let eq = eq_eval( + &point_and_eval.point, + &input_opening_point[0..point_and_eval.point.len()], + ); + // TODO times multiplication factor + *alpha * eq * in_eval + }) + .sum::(), + interleave(logup_p_point_and_eval, logup_q_point_and_eval) + .zip_eq(proof.lk_in_evals.iter()) + .zip_eq( + alpha_pow + .iter() + .skip(cs.r_table_expressions.len() + cs.w_table_expressions.len()), + ) + .map(|((point_and_eval, in_eval), alpha)| { + let eq = eq_eval( + &point_and_eval.point, + &input_opening_point[0..point_and_eval.point.len()], + ); + // TODO times multiplication factor + *alpha * eq * in_eval + }) + .sum::(), + ] .iter() - .map(|lk| &lk.values) - .chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity)) - .zip_eq( - proof.lk_d_in_evals[..lk_counts_per_instance] - .iter() - .chain(proof.lk_n_in_evals[..lk_counts_per_instance].iter()), + .sum::(); + if computed_evals != expected_evaluation { + return Err(ZKVMError::VerifyError( + "sel evaluation verify failed".into(), + )); + } + ( + input_opening_point, + [proof.rw_in_evals.to_vec(), proof.lk_in_evals.to_vec()].concat(), ) - .any(|(expr, expected_evals)| { - eval_by_expr_with_instance( - &proof.fixed_in_evals, - &proof.wits_in_evals, - pi, - challenges, - expr, - ) != *expected_evals - }) - { + }; + // verify records (degree = 
1) statement, thus no sumcheck + if interleave( + &cs.r_table_expressions, // r + &cs.w_table_expressions, // w + ) + .map(|rw| &rw.values) + .chain( + cs.lk_table_expressions + .iter() + .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q + ) + .zip_eq(in_evals) + .any(|(expr, expected_evals)| { + eval_by_expr_with_instance( + &proof.fixed_in_evals, + &proof.wits_in_evals, + pi, + challenges, + expr, + ) != expected_evals + }) { return Err(ZKVMError::VerifyError( "record evaluate != expected_evals".into(), )); @@ -612,7 +716,7 @@ impl TowerVerify { prod_out_evals: Vec>, logup_out_evals: Vec>, tower_proofs: &TowerProofs, - expected_rounds: Vec, + num_variables: Vec, num_fanin: usize, transcript: &mut Transcript, ) -> TowerVerifyResult { @@ -630,7 +734,7 @@ impl TowerVerify { assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); - assert_eq!(expected_rounds.len(), num_prod_spec + num_logup_spec); + assert_eq!(num_variables.len(), num_prod_spec + num_logup_spec); let alpha_pows = get_challenge_pows( num_prod_spec + num_logup_spec * 2, /* logup occupy 2 sumcheck: numerator and denominator */ @@ -660,9 +764,9 @@ impl TowerVerify { let mut logup_spec_p_input_layer_eval = vec![PointAndEval::default(); num_logup_spec]; let mut logup_spec_q_input_layer_eval = vec![PointAndEval::default(); num_logup_spec]; - let expected_max_round = expected_rounds.iter().max().unwrap(); + let max_num_variables = num_variables.iter().max().unwrap(); - let (next_rt, _) = (0..(expected_max_round-1)).try_fold( + let (next_rt, _) = (0..(max_num_variables-1)).try_fold( ( PointAndEval { point: initial_rt, @@ -685,13 +789,13 @@ impl TowerVerify { }, transcript, ); - tracing::debug!("verified tower proof at layer {}/{}", round + 1, expected_max_round-1); + tracing::debug!("verified tower proof at layer {}/{}", round + 1, max_num_variables-1); // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); let 
expected_evaluation: E = (0..num_prod_spec) .zip(alpha_pows.iter()) - .zip(expected_rounds.iter()) + .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { eq_eval(out_rt, &rt) * alpha @@ -702,7 +806,7 @@ impl TowerVerify { .sum::() + (0..num_logup_spec) .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) - .zip_eq(expected_rounds[num_prod_spec..].iter()) + .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); eq_eval(out_rt, &rt) * if round < *max_round-1 { @@ -738,7 +842,7 @@ impl TowerVerify { let next_round = round + 1; let next_prod_spec_evals = (0..num_prod_spec) .zip(next_alpha_pows.iter()) - .zip(expected_rounds.iter()) + .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { if round < max_round -1 { // merged evaluation @@ -762,7 +866,7 @@ impl TowerVerify { .sum::(); let next_logup_spec_evals = (0..num_logup_spec) .zip_eq(next_alpha_pows[num_prod_spec..].chunks(2)) - .zip_eq(expected_rounds[num_prod_spec..].iter()) + .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { if round < max_round -1 { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); diff --git a/ceno_zkvm/src/state.rs b/ceno_zkvm/src/state.rs new file mode 100644 index 000000000..875e8fbfb --- /dev/null +++ b/ceno_zkvm/src/state.rs @@ -0,0 +1,45 @@ +use ff_ext::ExtensionField; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr}, + structs::RAMType, +}; + +pub trait StateCircuit { + fn initial_global_state( + circuit_builder: &mut CircuitBuilder, + ) -> Result, ZKVMError>; + fn finalize_global_state( + circuit_builder: &mut CircuitBuilder, + ) -> Result, ZKVMError>; +} + +pub struct GlobalState; + +impl StateCircuit for GlobalState { + fn initial_global_state( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + ) -> Result, ZKVMError> { + let states: 
Vec> = vec![ + Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)), + circuit_builder.query_init_pc()?.expr(), + circuit_builder.query_init_cycle()?.expr(), + ]; + + Ok(circuit_builder.rlc_chip_record(states)) + } + + fn finalize_global_state( + circuit_builder: &mut crate::circuit_builder::CircuitBuilder, + ) -> Result, crate::error::ZKVMError> { + let states: Vec> = vec![ + Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)), + circuit_builder.query_end_pc()?.expr(), + circuit_builder.query_end_cycle()?.expr(), + ]; + + Ok(circuit_builder.rlc_chip_record(states)) + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index b89a8ecdb..a9b78f416 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,7 +1,9 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, + expression::Expression, instructions::Instruction, + state::StateCircuit, tables::TableCircuit, witness::{LkMultiplicity, RowMajorMatrix}, }; @@ -12,13 +14,14 @@ use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::DenseMultilinearExtension, virtual_poly_v2::ArcMultilinearExtension, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; +use strum_macros::EnumIter; use sumcheck::structs::IOPProverMessage; pub struct TowerProver; -#[derive(Clone, Serialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct TowerProofs { pub proofs: Vec>>, // specs -> layers -> evals @@ -40,10 +43,11 @@ pub struct TowerProverSpec<'a, E: ExtensionField> { pub type WitnessId = u16; pub type ChallengeId = u16; -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, EnumIter)] pub enum ROMType { U5 = 0, // 2^5 = 32 U8, // 2^8 = 256 + U14, // 2^14 = 16,384 U16, // 2^16 = 65,536 And, // a & b where a, b are bytes Or, // a | b where a, b are bytes @@ -121,9 +125,21 @@ impl> VerifyingKey } } -#[derive(Default, Clone)] +#[derive(Clone)] pub struct 
ZKVMConstraintSystem { pub(crate) circuit_css: BTreeMap>, + pub(crate) initial_global_state_expr: Expression, + pub(crate) finalize_global_state_expr: Expression, +} + +impl Default for ZKVMConstraintSystem { + fn default() -> Self { + ZKVMConstraintSystem { + circuit_css: BTreeMap::new(), + initial_global_state_expr: Expression::ZERO, + finalize_global_state_expr: Expression::ZERO, + } + } } impl ZKVMConstraintSystem { @@ -140,11 +156,20 @@ impl ZKVMConstraintSystem { let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let config = TC::construct_circuit(&mut circuit_builder).unwrap(); - assert!(self.circuit_css.insert(TC::name(), cs.clone()).is_none()); + assert!(self.circuit_css.insert(TC::name(), cs).is_none()); config } + pub fn register_global_state>(&mut self) { + let mut cs = ConstraintSystem::new(|| "riscv_state"); + let mut circuit_builder = CircuitBuilder::::new(&mut cs); + self.initial_global_state_expr = + SC::initial_global_state(&mut circuit_builder).expect("global_state_in failed"); + self.finalize_global_state_expr = + SC::finalize_global_state(&mut circuit_builder).expect("global_state_out failed"); + } + pub fn get_cs(&self, name: &String) -> Option<&ConstraintSystem> { self.circuit_css.get(name) } @@ -260,6 +285,10 @@ pub struct ZKVMProvingKey> pub vp: PCS::VerifierParam, // pk for opcode and table circuits pub circuit_pks: BTreeMap>, + + // expression for global state in/out + pub initial_global_state_expr: Expression, + pub finalize_global_state_expr: Expression, } impl> ZKVMProvingKey { @@ -268,6 +297,8 @@ impl> ZKVMProvingKey> ZKVMProvingKey> { pub vp: PCS::VerifierParam, - // pk for opcode and table circuits + // vk for opcode and table circuits pub circuit_vks: BTreeMap>, + // expression for global state in/out + pub initial_global_state_expr: Expression, + pub finalize_global_state_expr: Expression, } diff --git a/ceno_zkvm/src/tables/mod.rs 
b/ceno_zkvm/src/tables/mod.rs index 138bcb546..2ef7e293a 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,7 +1,11 @@ -use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, witness::RowMajorMatrix}; +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, scheme::constants::MIN_PAR_SIZE, + witness::RowMajorMatrix, +}; +use ff::Field; use ff_ext::ExtensionField; -use std::collections::HashMap; - +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use std::{collections::HashMap, mem::MaybeUninit}; mod range; pub use range::*; @@ -11,6 +15,9 @@ pub use ops::*; mod program; pub use program::{InsnRecord, ProgramTableCircuit}; +mod ram; +pub use ram::*; + pub trait TableCircuit { type TableConfig: Send + Sync; type FixedInput: Send + Sync + ?Sized; @@ -34,4 +41,30 @@ pub trait TableCircuit { multiplicity: &[HashMap], input: &Self::WitnessInput, ) -> Result, ZKVMError>; + + fn padding_zero( + table: &mut RowMajorMatrix, + num_witin: usize, + ) -> Result<(), ZKVMError> { + // Fill the padding with zeros, if any. 
+ let num_padding_instances = table.num_padding_instances(); + if num_padding_instances > 0 { + let nthreads = + std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let padding_instance = vec![MaybeUninit::new(E::BaseField::ZERO); num_witin]; + let num_padding_instance_per_batch = if num_padding_instances > 256 { + num_padding_instances.div_ceil(nthreads) + } else { + num_padding_instances + }; + table + .par_batch_iter_padding_mut(num_padding_instance_per_batch) + .with_min_len(MIN_PAR_SIZE) + .for_each(|row| { + row.chunks_mut(num_witin) + .for_each(|instance| instance.copy_from_slice(padding_instance.as_slice())); + }); + } + Ok(()) + } } diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index 0a0026661..cd48ebe19 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -42,7 +42,7 @@ impl TableCircuit for OpsTableCircuit fn construct_circuit(cb: &mut CircuitBuilder) -> Result { cb.namespace( || Self::name(), - |cb| OpTableConfig::construct_circuit(cb, OP::ROM_TYPE), + |cb| OpTableConfig::construct_circuit(cb, OP::ROM_TYPE, OP::len()), ) } @@ -51,7 +51,9 @@ impl TableCircuit for OpsTableCircuit num_fixed: usize, _input: &(), ) -> RowMajorMatrix { - config.generate_fixed_traces(num_fixed, OP::content()) + let mut table = config.generate_fixed_traces(num_fixed, OP::content()); + Self::padding_zero(&mut table, num_fixed).expect("padding error"); + table } fn assign_instances( @@ -61,6 +63,8 @@ impl TableCircuit for OpsTableCircuit _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; - config.assign_instances(num_witin, multiplicity, OP::len()) + let mut table = config.assign_instances(num_witin, multiplicity, OP::len())?; + Self::padding_zero(&mut table, num_witin)?; + Ok(table) } } diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 937e3fa2e..f9bf4d0c8 100644 --- 
a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -25,6 +25,7 @@ impl OpTableConfig { pub fn construct_circuit( cb: &mut CircuitBuilder, rom_type: ROMType, + table_len: usize, ) -> Result { let abc = [ cb.create_fixed(|| "a")?, @@ -40,7 +41,7 @@ impl OpTableConfig { Expression::Fixed(abc[2]), ]); - cb.lk_table_record(|| "record", rlc_record, mlt.expr())?; + cb.lk_table_record(|| "record", table_len, rlc_record, mlt.expr())?; Ok(Self { abc, mlt }) } @@ -52,13 +53,6 @@ impl OpTableConfig { ) -> RowMajorMatrix { let mut fixed = RowMajorMatrix::::new(content.len(), num_fixed); - // Fill the padding with zeros, if any. - fixed.par_iter_mut().skip(content.len()).for_each(|row| { - for col in &self.abc { - set_fixed_val!(row, *col, F::ZERO); - } - }); - fixed .par_iter_mut() .with_min_len(MIN_PAR_SIZE) @@ -85,11 +79,6 @@ impl OpTableConfig { mlts[*idx as usize] = *mlt; } - // Fill the padding with zeros, if any. - witness.par_iter_mut().skip(length).for_each(|row| { - set_val!(row, self.mlt, F::ZERO); - }); - witness .par_iter_mut() .with_min_len(MIN_PAR_SIZE) diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 033c469ee..3514365c8 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -10,12 +10,26 @@ use crate::{ tables::TableCircuit, witness::RowMajorMatrix, }; -use ceno_emul::{CENO_PLATFORM, DecodedInstruction, PC_STEP_SIZE, WORD_SIZE, Word}; +use ceno_emul::{CENO_PLATFORM, DecodedInstruction, PC_STEP_SIZE, WORD_SIZE}; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; +#[macro_export] +macro_rules! declare_program { + ($program:ident, $($instr:expr),* $(,)?) 
=> { + + { + let mut _i = 0; + $( + $program[_i] = $instr; + _i += 1; + )* + } + }; +} + #[derive(Clone, Debug)] pub struct InsnRecord([T; 7]); @@ -97,11 +111,13 @@ pub struct ProgramTableConfig { mlt: WitIn, } -pub struct ProgramTableCircuit(PhantomData); +pub struct ProgramTableCircuit(PhantomData); -impl TableCircuit for ProgramTableCircuit { +impl TableCircuit + for ProgramTableCircuit +{ type TableConfig = ProgramTableConfig; - type FixedInput = [u32]; + type FixedInput = [u32; PROGRAM_SIZE]; type WitnessInput = usize; fn name() -> String { @@ -127,7 +143,7 @@ impl TableCircuit for ProgramTableCircuit { cb.rlc_chip_record(fields) }; - cb.lk_table_record(|| "prog table", record_exprs, mlt.expr())?; + cb.lk_table_record(|| "prog table", PROGRAM_SIZE, record_exprs, mlt.expr())?; Ok(ProgramTableConfig { record, mlt }) } @@ -135,7 +151,7 @@ impl TableCircuit for ProgramTableCircuit { fn generate_fixed_traces( config: &ProgramTableConfig, num_fixed: usize, - program: &[Word], + program: &Self::FixedInput, ) -> RowMajorMatrix { // TODO: get bytecode of the program. let num_instructions = program.len(); diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs new file mode 100644 index 000000000..decd52608 --- /dev/null +++ b/ceno_zkvm/src/tables/ram.rs @@ -0,0 +1,57 @@ +use ceno_emul::{Addr, CENO_PLATFORM, WORD_SIZE, Word}; +use ram_circuit::RamTableCircuit; + +use crate::{instructions::riscv::constants::UINT_LIMBS, structs::RAMType}; + +mod ram_circuit; +mod ram_impl; +use ram_circuit::RamTable; +pub use ram_circuit::{MemFinalRecord, MemInitRecord}; + +#[derive(Clone)] +pub struct MemTable; + +impl RamTable for MemTable { + const RAM_TYPE: RAMType = RAMType::Memory; + const V_LIMBS: usize = 1; // See `MemoryExpr`. 
+ + fn len() -> usize { + // TODO figure out better way to define memory entry count + 1 << 10 + } + + fn addr(entry_index: usize) -> Addr { + CENO_PLATFORM.ram_start() + (entry_index * WORD_SIZE) as Addr + } +} +pub type MemTableCircuit = RamTableCircuit; + +#[derive(Clone)] +pub struct RegTable; + +impl RamTable for RegTable { + const RAM_TYPE: RAMType = RAMType::Register; + const V_LIMBS: usize = UINT_LIMBS; // See `RegisterExpr`. + + fn len() -> usize { + 32 // register size 32 + } + + fn addr(entry_index: usize) -> Addr { + entry_index as Addr + } +} + +pub type RegTableCircuit = RamTableCircuit; + +pub fn initial_registers() -> Vec { + RegTable::init_state() +} + +pub fn initial_memory(ram_content: &[Word]) -> Vec { + let mut mem_init = MemTable::init_state(); + for (i, value) in ram_content.iter().enumerate() { + mem_init[i].value = *value; + } + mem_init +} diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs new file mode 100644 index 000000000..d384d751f --- /dev/null +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -0,0 +1,83 @@ +use std::{collections::HashMap, marker::PhantomData}; + +use ceno_emul::{Addr, Cycle, Word}; +use ff_ext::ExtensionField; + +use crate::{ + circuit_builder::CircuitBuilder, error::ZKVMError, structs::RAMType, tables::TableCircuit, + witness::RowMajorMatrix, +}; + +use super::ram_impl::RamTableConfig; + +#[derive(Clone, Debug)] +pub struct MemInitRecord { + pub addr: Addr, + pub value: Word, +} + +pub struct MemFinalRecord { + pub cycle: Cycle, + pub value: Word, +} + +/// Impl trait as parameter to RamTableCircuit. 
+pub trait RamTable { + const RAM_TYPE: RAMType; + const V_LIMBS: usize; + + fn len() -> usize; + + fn addr(entry_index: usize) -> Addr; + + fn init_state() -> Vec { + (0..Self::len()) + .map(|i| MemInitRecord { + addr: Self::addr(i), + value: 0, + }) + .collect() + } +} + +pub struct RamTableCircuit(PhantomData<(E, R)>); + +impl TableCircuit + for RamTableCircuit +{ + type TableConfig = RamTableConfig; + type FixedInput = [MemInitRecord]; + type WitnessInput = [MemFinalRecord]; + + fn name() -> String { + format!("RAM_{:?}", RAM::RAM_TYPE) + } + + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + cb.namespace( + || Self::name(), + |cb| Self::TableConfig::construct_circuit(cb), + ) + } + + fn generate_fixed_traces( + config: &Self::TableConfig, + num_fixed: usize, + init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + let mut table = config.gen_init_state(num_fixed, init_v); + Self::padding_zero(&mut table, num_fixed).expect("padding error"); + table + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + _multiplicity: &[HashMap], + final_v: &Self::WitnessInput, + ) -> Result, ZKVMError> { + let mut table = config.assign_instances(num_witin, final_v)?; + Self::padding_zero(&mut table, num_witin)?; + Ok(table) + } +} diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs new file mode 100644 index 000000000..a4d263123 --- /dev/null +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -0,0 +1,140 @@ +use std::{marker::PhantomData, mem::MaybeUninit}; + +use ff_ext::ExtensionField; +use goldilocks::SmallField; +use itertools::Itertools; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; + +use crate::{ + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, Fixed, ToExpr, WitIn}, + instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, + scheme::constants::MIN_PAR_SIZE, + set_fixed_val, set_val, + witness::RowMajorMatrix, +}; + +use super::{ + 
MemInitRecord, + ram_circuit::{MemFinalRecord, RamTable}, +}; + +#[derive(Clone, Debug)] +pub struct RamTableConfig { + init_v: Vec, + addr: Fixed, + + final_v: Vec, + final_cycle: WitIn, + + phantom: PhantomData, +} + +impl RamTableConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + ) -> Result { + let init_v = (0..RAM::V_LIMBS) + .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) + .collect::, ZKVMError>>()?; + let addr = cb.create_fixed(|| "addr")?; + + let final_v = (0..RAM::V_LIMBS) + .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) + .collect::, ZKVMError>>()?; + let final_cycle = cb.create_witin(|| "final_cycle")?; + + let init_table = cb.rlc_chip_record( + [ + vec![(RAM::RAM_TYPE as usize).into()], + vec![Expression::Fixed(addr)], + init_v.iter().map(|v| v.expr()).collect_vec(), + vec![Expression::ZERO], // Initial cycle. + ] + .concat(), + ); + + let final_table = cb.rlc_chip_record( + [ + // a v t + vec![(RAM::RAM_TYPE as usize).into()], + vec![Expression::Fixed(addr)], + final_v.iter().map(|v| v.expr()).collect_vec(), + vec![final_cycle.expr()], + ] + .concat(), + ); + + cb.w_table_record(|| "init_table", RAM::len(), init_table)?; + cb.r_table_record(|| "final_table", RAM::len(), final_table)?; + + Ok(Self { + init_v, + addr, + final_v, + final_cycle, + phantom: PhantomData, + }) + } + + pub fn gen_init_state( + &self, + num_fixed: usize, + init_v: &[MemInitRecord], + ) -> RowMajorMatrix { + assert_eq!(init_v.len(), RAM::len()); + // for ram in memory offline check + let mut init_table = RowMajorMatrix::::new(RAM::len(), num_fixed); + + init_table + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(init_v.into_par_iter()) + .for_each(|(row, rec)| { + if self.init_v.len() == 1 { + // Assign value directly. + set_fixed_val!(row, self.init_v[0], (rec.value as u64).into()); + } else { + // Assign value limbs. 
+ self.init_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_fixed_val!(row, limb, (val as u64).into()); + }); + } + set_fixed_val!(row, self.addr, (rec.addr as u64).into()); + }); + + init_table + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + num_witness: usize, + final_v: &[MemFinalRecord], + ) -> Result, ZKVMError> { + assert_eq!(final_v.len(), RAM::len()); + let mut final_table = RowMajorMatrix::::new(RAM::len(), num_witness); + + final_table + .par_iter_mut() + .with_min_len(MIN_PAR_SIZE) + .zip(final_v.into_par_iter()) + .for_each(|(row, rec)| { + if self.final_v.len() == 1 { + // Assign value directly. + set_fixed_val!(row, self.init_v[0], (rec.value as u64).into()); + } else { + // Assign value limbs. + self.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.final_cycle, rec.cycle); + }); + + Ok(final_table) + } +} diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index cdbf5aca1..319f3c850 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -25,6 +25,15 @@ impl RangeTable for U8Table { } pub type U8TableCircuit = RangeTableCircuit; +pub struct U14Table; +impl RangeTable for U14Table { + const ROM_TYPE: ROMType = ROMType::U14; + fn len() -> usize { + 1 << 14 + } +} +pub type U14TableCircuit = RangeTableCircuit; + pub struct U16Table; impl RangeTable for U16Table { const ROM_TYPE: ROMType = ROMType::U16; diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index e1132f2be..bb7c83448 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -35,7 +35,7 @@ impl TableCircuit for RangeTableCircuit fn construct_circuit(cb: &mut CircuitBuilder) -> Result { 
cb.namespace( || Self::name(), - |cb| RangeTableConfig::construct_circuit(cb, RANGE::ROM_TYPE), + |cb| RangeTableConfig::construct_circuit(cb, RANGE::ROM_TYPE, RANGE::len()), ) } @@ -44,7 +44,9 @@ impl TableCircuit for RangeTableCircuit num_fixed: usize, _input: &(), ) -> RowMajorMatrix { - config.generate_fixed_traces(num_fixed, RANGE::content()) + let mut table = config.generate_fixed_traces(num_fixed, RANGE::content()); + Self::padding_zero(&mut table, num_fixed).expect("padding error"); + table } fn assign_instances( @@ -54,6 +56,8 @@ impl TableCircuit for RangeTableCircuit _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize]; - config.assign_instances(num_witin, multiplicity, RANGE::len()) + let mut table = config.assign_instances(num_witin, multiplicity, RANGE::len())?; + Self::padding_zero(&mut table, num_witin).expect("padding error"); + Ok(table) } } diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index f40f5d279..8a14fe236 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -25,6 +25,7 @@ impl RangeTableConfig { pub fn construct_circuit( cb: &mut CircuitBuilder, rom_type: ROMType, + table_len: usize, ) -> Result { let fixed = cb.create_fixed(|| "fixed")?; let mlt = cb.create_witin(|| "mlt")?; @@ -32,7 +33,7 @@ impl RangeTableConfig { let rlc_record = cb.rlc_chip_record(vec![(rom_type as usize).into(), Expression::Fixed(fixed)]); - cb.lk_table_record(|| "record", rlc_record, mlt.expr())?; + cb.lk_table_record(|| "record", table_len, rlc_record, mlt.expr())?; Ok(Self { fixed, mlt }) } @@ -44,11 +45,6 @@ impl RangeTableConfig { ) -> RowMajorMatrix { let mut fixed = RowMajorMatrix::::new(content.len(), num_fixed); - // Fill the padding with zeros, if any. 
- fixed.par_iter_mut().skip(content.len()).for_each(|row| { - set_fixed_val!(row, self.fixed, F::ZERO); - }); - fixed .par_iter_mut() .with_min_len(MIN_PAR_SIZE) @@ -73,11 +69,6 @@ impl RangeTableConfig { mlts[*idx as usize] = *mlt; } - // Fill the padding with zeros, if any. - witness.par_iter_mut().skip(length).for_each(|row| { - set_val!(row, self.mlt, F::ZERO); - }); - witness .par_iter_mut() .with_min_len(MIN_PAR_SIZE) diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index 69abfc7a1..6a9732e0e 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -4,7 +4,7 @@ mod logic; pub mod util; use crate::{ - chip_handler::{MemoryExpr, RegisterExpr}, + chip_handler::{AddressExpr, MemoryExpr, RegisterExpr}, circuit_builder::CircuitBuilder, error::{UtilError, ZKVMError}, expression::{Expression, ToExpr, WitIn}, @@ -88,7 +88,7 @@ impl UIntLimbs { circuit_builder.namespace(name_fn, |cb| { Ok(UIntLimbs { limbs: UintLimb::WitIn( - (0..Self::NUM_CELLS) + (0..Self::NUM_LIMBS) .map(|i| { let w = cb.create_witin(|| format!("limb_{i}"))?; if is_check { @@ -112,7 +112,7 @@ impl UIntLimbs { carries: Option>, carries_auxiliary_lt_config: Option>, ) -> Self { - assert!(limbs.len() == Self::NUM_CELLS); + assert!(limbs.len() == Self::NUM_LIMBS); if let Some(carries) = &carries { let diff = limbs.len() - carries.len(); assert!( @@ -131,12 +131,12 @@ impl UIntLimbs { /// take vector of primative type and instantiate witnesses pub fn from_const_unchecked>(limbs: Vec) -> Self { - assert!(limbs.len() == Self::NUM_CELLS); + assert!(limbs.len() == Self::NUM_LIMBS); UIntLimbs { limbs: UintLimb::Expression( limbs .into_iter() - .take(Self::NUM_CELLS) + .take(Self::NUM_LIMBS) .map(|limb| Expression::Constant(E::BaseField::from(limb.into()))) .collect::>>(), ), @@ -159,8 +159,8 @@ impl UIntLimbs { circuit_builder: &mut CircuitBuilder, expr_limbs: Vec>, ) -> Self { - assert_eq!(expr_limbs.len(), Self::NUM_CELLS); - let limbs = (0..Self::NUM_CELLS) + assert_eq!(expr_limbs.len(), 
Self::NUM_LIMBS); + let limbs = (0..Self::NUM_LIMBS) .map(|i| { let w = circuit_builder.create_witin(|| "wit for limb").unwrap(); circuit_builder @@ -208,10 +208,10 @@ impl UIntLimbs { pub fn assign_limbs(&self, instance: &mut [MaybeUninit], limbs_values: &[u16]) { assert!( - limbs_values.len() <= Self::NUM_CELLS, + limbs_values.len() <= Self::NUM_LIMBS, "assign input length mismatch. input_len={}, NUM_CELLS={}", limbs_values.len(), - Self::NUM_CELLS + Self::NUM_LIMBS ); if let UintLimb::WitIn(wires) = &self.limbs { for (wire, limb) in wires.iter().zip( @@ -285,9 +285,7 @@ impl UIntLimbs { let shift_pows = { let mut shift_pows = Vec::with_capacity(k); shift_pows.push(Expression::Constant(E::BaseField::ONE)); - (0..k - 1).for_each(|_| { - shift_pows.push(shift_pows.last().unwrap().clone() * (1 << 8).into()) - }); + (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap().clone() * (1 << 8))); shift_pows }; let combined_limbs = x @@ -304,7 +302,7 @@ impl UIntLimbs { .unwrap() }) .collect_vec(); - UIntLimbs::::from_exprs_unchecked(combined_limbs) + Ok(UIntLimbs::::from_exprs_unchecked(combined_limbs)) } pub fn to_u8_limbs( @@ -317,9 +315,7 @@ impl UIntLimbs { let shift_pows = { let mut shift_pows = Vec::with_capacity(k); shift_pows.push(Expression::Constant(E::BaseField::ONE)); - (0..k - 1).for_each(|_| { - shift_pows.push(shift_pows.last().unwrap().clone() * (1 << 8).into()) - }); + (0..k - 1).for_each(|_| shift_pows.push(shift_pows.last().unwrap().clone() * (1 << 8))); shift_pows }; let split_limbs = x @@ -349,19 +345,18 @@ impl UIntLimbs { UIntLimbs::::create_witin_from_exprs(circuit_builder, split_limbs) } - pub fn from_exprs_unchecked(expr_limbs: Vec>) -> Result { - let n = Self { + pub fn from_exprs_unchecked(expr_limbs: Vec>) -> Self { + Self { limbs: UintLimb::Expression( expr_limbs .into_iter() .chain(std::iter::repeat(Expression::ZERO)) - .take(Self::NUM_CELLS) + .take(Self::NUM_LIMBS) .collect_vec(), ), carries: None, carries_auxiliary_lt_config: 
None, - }; - Ok(n) + } } /// If current limbs are Expression, this function will create witIn and replace the limbs @@ -373,7 +368,7 @@ impl UIntLimbs { if let UintLimb::Expression(_) = self.limbs { circuit_builder.namespace(name_fn, |cb| { self.limbs = UintLimb::WitIn( - (0..Self::NUM_CELLS) + (0..Self::NUM_LIMBS) .map(|i| { let w = cb.create_witin(|| format!("limb_{i}"))?; cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; @@ -527,10 +522,16 @@ impl UIntLimbs { let mut self_lo = self.expr(); let self_hi = self_lo.split_off(self_lo.len() / 2); Ok(( - UIntLimbs::from_exprs_unchecked(self_lo)?, - UIntLimbs::from_exprs_unchecked(self_hi)?, + UIntLimbs::from_exprs_unchecked(self_lo), + UIntLimbs::from_exprs_unchecked(self_hi), )) } + + pub fn to_field_expr(&self, is_neg: Expression) -> Expression { + // Convert two's complement representation into field arithmetic. + // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 + self.value() - is_neg * (1_u64 << 32) + } } /// Construct `UIntLimbs` from `Vec` @@ -538,13 +539,13 @@ impl TryFrom> for type Error = UtilError; fn try_from(limbs: Vec) -> Result { - if limbs.len() != Self::NUM_CELLS { + if limbs.len() != Self::NUM_LIMBS { return Err(UtilError::UIntError(format!( "cannot construct UIntLimbs<{}, {}> from {} cells, requires {} cells", M, C, limbs.len(), - Self::NUM_CELLS + Self::NUM_LIMBS ))); } @@ -585,10 +586,14 @@ impl UIntLimbs<32, 16, E> { u16_limbs.try_into().expect("two limbs with M=32 and C=16") } + /// Interpret this UInt as a memory address. + pub fn address_expr(&self) -> AddressExpr { + self.value() + } + /// Return a value suitable for memory read/write. 
From [u16; 2] limbs pub fn memory_expr(&self) -> MemoryExpr { - let u16_limbs = self.expr(); - u16_limbs.try_into().expect("two limbs with M=32 and C=16") + self.value() } } @@ -600,7 +605,7 @@ impl UIntLimbs<32, 8, E> { .chunks(2) .map(|chunk| { let (a, b) = (chunk[0].clone(), chunk[1].clone()); - a + b * 256.into() + a + b * 256 }) .collect_vec(); u16_limbs.try_into().expect("four limbs with M=32 and C=8") @@ -630,6 +635,7 @@ impl ValueMul { } } +#[derive(Clone)] pub struct Value<'a, T: Into + From + Copy + Default> { #[allow(dead_code)] val: T, diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index bef8c71cf..c93bb73ae 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -5,7 +5,6 @@ use itertools::{Itertools, izip}; use super::{UIntLimbs, UintLimb}; use crate::{ circuit_builder::CircuitBuilder, - create_witin_from_expr, error::ZKVMError, expression::{Expression, ToExpr, WitIn}, gadgets::AssertLTConfig, @@ -29,7 +28,7 @@ impl UIntLimbs { || "add_carry", circuit_builder, with_overflow, - Self::NUM_CELLS, + Self::NUM_LIMBS, )?; let Some(carries) = &c.carries else { return Err(ZKVMError::CircuitError); @@ -55,8 +54,7 @@ impl UIntLimbs { limb_expr = limb_expr.clone() + carry.unwrap().expr(); } if next_carry.is_some() { - limb_expr = - limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C.into(); + limb_expr = limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C; } circuit_builder @@ -83,7 +81,7 @@ impl UIntLimbs { let b = c.to_canonical_u64(); // convert Expression::Constant to limbs - let b_limbs = (0..Self::NUM_CELLS) + let b_limbs = (0..Self::NUM_LIMBS) .map(|i| { Expression::Constant(E::BaseField::from((b >> (C * i)) & Self::LIMB_BIT_MASK)) }) @@ -115,9 +113,9 @@ impl UIntLimbs { debug_assert!(M2 == M || M2 == 2 * M, "illegal M2 {M2} and M {M}"); let is_hi_limb = M2 == 2 * M; let num_limbs = if is_hi_limb { - 2 * Self::NUM_CELLS + 2 * Self::NUM_LIMBS } else { - 
Self::NUM_CELLS + Self::NUM_LIMBS }; // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { @@ -199,8 +197,7 @@ impl UIntLimbs { result_c[i] = result_c[i].clone() + carry.unwrap().expr(); } if next_carry.is_some() { - result_c[i] = - result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C.into(); + result_c[i] = result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C; } circuit_builder.require_zero(|| format!("mul_zero_{i}"), result_c[i].clone())?; Ok::<(), ZKVMError>(()) @@ -244,7 +241,7 @@ impl UIntLimbs { mul_hi } else { // lo limb - UIntLimbs::from_exprs_unchecked(mul.expr())? + UIntLimbs::from_exprs_unchecked(mul.expr()) }; let add = cb.namespace( || "add", @@ -269,7 +266,7 @@ impl UIntLimbs { circuit_builder: &mut CircuitBuilder, rhs: &UIntLimbs, ) -> Result { - let n_limbs = Self::NUM_CELLS; + let n_limbs = Self::NUM_LIMBS; let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = self .limbs .iter() @@ -283,7 +280,7 @@ impl UIntLimbs { .iter() .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); - let sum_flag = create_witin_from_expr!(|| "sum_flag", circuit_builder, false, sum_expr)?; + let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; let (is_equal, diff_inv) = circuit_builder.is_equal(sum_flag.expr(), Expression::from(n_limbs))?; Ok(IsEqualConfig { @@ -306,7 +303,7 @@ impl UIntLimbs { E: ExtensionField, { let high_limb_no_msb = circuit_builder.create_witin(|| "high_limb_mask")?; - let high_limb = self.limbs[Self::NUM_CELLS - 1].expr(); + let high_limb = self.limbs[Self::NUM_LIMBS - 1].expr(); circuit_builder.lookup_and_byte( high_limb.clone(), @@ -316,7 +313,7 @@ impl UIntLimbs { let inv_128 = F::from(128).invert().unwrap(); let msb = (high_limb - high_limb_no_msb.expr()) * Expression::Constant(inv_128); - let msb = create_witin_from_expr!(|| "msb", circuit_builder, false, msb)?; + let msb = WitIn::from_expr(|| "msb", 
circuit_builder, msb, false)?; Ok(MsbConfig { msb, high_limb_no_msb, @@ -329,7 +326,7 @@ impl UIntLimbs { circuit_builder: &mut CircuitBuilder, rhs: &UIntLimbs, ) -> Result { - let n_bytes = Self::NUM_CELLS; + let n_bytes = Self::NUM_LIMBS; let indexes: Vec = (0..n_bytes) .map(|_| circuit_builder.create_witin(|| "index")) .collect::>()?; @@ -361,7 +358,7 @@ impl UIntLimbs { .rev() .enumerate() .map(|(i, expr)| { - create_witin_from_expr!(|| format!("si_expr_{i}"), circuit_builder, false, expr) + WitIn::from_expr(|| format!("si_expr_{i}"), circuit_builder, expr, false) }) .collect::, ZKVMError>>()?; @@ -396,10 +393,8 @@ impl UIntLimbs { // check the first byte difference has a inverse // unwrap is safe because vector len > 0 - let lhs_ne_byte = - create_witin_from_expr!(|| "lhs_ne_byte", circuit_builder, false, sa.clone())?; - let rhs_ne_byte = - create_witin_from_expr!(|| "rhs_ne_byte", circuit_builder, false, sb.clone())?; + let lhs_ne_byte = WitIn::from_expr(|| "lhs_ne_byte", circuit_builder, sa.clone(), false)?; + let rhs_ne_byte = WitIn::from_expr(|| "rhs_ne_byte", circuit_builder, sb.clone(), false)?; let index_ne = si.first().unwrap(); circuit_builder.require_zero( || "byte inverse check", @@ -433,10 +428,10 @@ impl UIntLimbs { let rhs_msb = rhs.msb_decompose(circuit_builder)?; let mut lhs_limbs = self.limbs.iter().copied().collect_vec(); - lhs_limbs[Self::NUM_CELLS - 1] = lhs_msb.high_limb_no_msb; + lhs_limbs[Self::NUM_LIMBS - 1] = lhs_msb.high_limb_no_msb; let lhs_no_msb = Self::from_witins_unchecked(lhs_limbs, None, None); let mut rhs_limbs = rhs.limbs.iter().copied().collect_vec(); - rhs_limbs[Self::NUM_CELLS - 1] = rhs_msb.high_limb_no_msb; + rhs_limbs[Self::NUM_LIMBS - 1] = rhs_msb.high_limb_no_msb; let rhs_no_msb = Self::from_witins_unchecked(rhs_limbs, None, None); // (1) compute ltu(a_{::MAX_CELL_BIT_WIDTH as u32) as u64; - let single_wit_size = UIntLimbs::::NUM_CELLS; + let pow_of_c: u64 = 2_usize.pow(UIntLimbs::::MAX_LIMB_BIT_WIDTH as u32) as u64; 
+ let single_wit_size = UIntLimbs::::NUM_LIMBS; let a = &witness_values[0..single_wit_size]; let mut const_b_pre_allocated = vec![0u64; single_wit_size]; @@ -781,8 +776,8 @@ mod tests { witness_values: Vec, overflow: bool, ) { - let pow_of_c: u64 = 2_usize.pow(UIntLimbs::::MAX_CELL_BIT_WIDTH as u32) as u64; - let single_wit_size = UIntLimbs::::NUM_CELLS; + let pow_of_c: u64 = 2_usize.pow(UIntLimbs::::MAX_LIMB_BIT_WIDTH as u32) as u64; + let single_wit_size = UIntLimbs::::NUM_LIMBS; if overflow { assert_eq!( witness_values.len() % single_wit_size, @@ -937,7 +932,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); } #[test] @@ -987,7 +982,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_g) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); } #[test] @@ -1026,7 +1021,7 @@ mod tests { .require_equal(|| "assert_e", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); } #[test] @@ -1065,7 +1060,7 @@ mod tests { .require_equal(|| "assert_e", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); } #[test] @@ -1102,7 +1097,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_c) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], None, None); } } } diff --git a/ceno_zkvm/src/uint/constants.rs b/ceno_zkvm/src/uint/constants.rs index dc917bc92..c418213f0 100644 --- a/ceno_zkvm/src/uint/constants.rs +++ b/ceno_zkvm/src/uint/constants.rs @@ -6,24 +6,26 @@ pub const BYTE_BIT_WIDTH: usize = 8; use ff_ext::ExtensionField; -impl 
UIntLimbs { - pub const M: usize = M; - pub const C: usize = C; - - /// Determines the maximum number of bits that should be represented in each cell - /// independent of the cell capacity `C`. - /// If M < C i.e. total bit < cell capacity, the maximum_usable_cell_capacity - /// is actually M. - /// but if M >= C then maximum_usable_cell_capacity = C - pub const MAX_CELL_BIT_WIDTH: usize = const_min(M, C); - - /// `NUM_CELLS` represent the minimum number of cells each of size `C` needed - /// to hold `M` total bits - pub const NUM_CELLS: usize = M.div_ceil(C); +impl + UIntLimbs +{ + pub const TOTAL_BITS: usize = TOTAL_BITS; + pub const LIMB_BITS: usize = CAPACITY; + + /// Determines the maximum number of bits that should be represented in each limb + /// independent of the limb capacity. + /// If total bits < limb capacity, the maximum_usable_limb_capacity + /// is actually 'total bits'. + /// but if total bits >= limb capacity then maximum_usable_limb_capacity = 'limb capacity'. + pub const MAX_LIMB_BIT_WIDTH: usize = const_min(TOTAL_BITS, CAPACITY); + + /// `NUM_LIMBS` represent the minimum number of limbs needed + /// to hold total bits + pub const NUM_LIMBS: usize = TOTAL_BITS.div_ceil(CAPACITY); /// Max carry value during degree 2 limb multiplication pub const MAX_DEGREE_2_MUL_CARRY_VALUE: u64 = - max_carry_word_for_multiplication(2, Self::M, Self::C); + max_carry_word_for_multiplication(2, Self::TOTAL_BITS, Self::LIMB_BITS); /// Min bits to cover MAX_DEGREE_2_MUL_CARRY_VALUE pub const MAX_DEGREE_2_MUL_CARRY_BITS: usize = { @@ -33,5 +35,5 @@ impl UIntLimbs { /// Min number of u16 limb to cover max carry value pub const MAX_DEGREE_2_MUL_CARRY_U16_LIMB: usize = - (Self::MAX_DEGREE_2_MUL_CARRY_BITS + 15) / 16; + Self::MAX_DEGREE_2_MUL_CARRY_BITS.div_ceil(16); } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index b6c01e933..47af02967 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -2,6 +2,7 @@ use ff::Field; use 
ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; use transcript::Transcript; /// convert ext field element to u64, assume it is inside the range @@ -20,7 +21,7 @@ pub fn i64_to_base(x: i64) -> F { } /// This is helper function to convert witness of u8 limb into u16 limb -/// TODO: need a better way to keep consistency of VALUE_BIT_WIDTH +/// TODO: need a better way to keep consistency of LIMB_BITS #[allow(dead_code)] pub fn limb_u8_to_u16(input: &[u8]) -> Vec { input @@ -113,7 +114,8 @@ pub fn u64vec(x: u64) -> [u64; W] { /// we expect each thread at least take 4 num of sumcheck variables /// return optimal num threads to run sumcheck -pub fn proper_num_threads(num_vars: usize, expected_max_threads: usize) -> usize { +pub fn optimal_sumcheck_threads(num_vars: usize) -> usize { + let expected_max_threads = max_usable_threads(); let min_numvar_per_thread = 4; if num_vars <= min_numvar_per_thread { 1 diff --git a/ceno_zkvm/src/virtual_polys.rs b/ceno_zkvm/src/virtual_polys.rs index 5d81ea042..4019f2d22 100644 --- a/ceno_zkvm/src/virtual_polys.rs +++ b/ceno_zkvm/src/virtual_polys.rs @@ -22,11 +22,11 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> { } impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { - pub fn new(num_threads: usize, num_variables: usize) -> Self { + pub fn new(num_threads: usize, max_num_variables: usize) -> Self { VirtualPolynomials { num_threads, polys: (0..num_threads) - .map(|_| VirtualPolynomialV2::new(num_variables - ceil_log2(num_threads))) + .map(|_| VirtualPolynomialV2::new(max_num_variables - ceil_log2(num_threads))) .collect_vec(), thread_based_mles_storage: HashMap::new(), } diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index c3dee7e45..bdffa2ec6 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -21,7 +21,7 @@ use thread_local::ThreadLocal; use crate::{ structs::ROMType, - tables::{AndTable, LtuTable, 
OpsTable, OrTable, XorTable}, + tables::{AndTable, LtuTable, OpsTable, OrTable, PowTable, XorTable}, utils::next_pow2_instance_padding, }; @@ -127,7 +127,7 @@ impl Index for RowMajorMatrix { /// A lock-free thread safe struct to count logup multiplicity for each ROM type /// Lock-free by thread-local such that each thread will only have its local copy /// struct is cloneable, for internallly it use Arc so the clone will be low cost -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] #[allow(clippy::type_complexity)] pub struct LkMultiplicity { multiplicity: Arc; mem::variant_count::()]>>>, @@ -140,6 +140,7 @@ impl LkMultiplicity { pub fn assert_ux(&mut self, v: u64) { match C { 16 => self.increment(ROMType::U16, v), + 14 => self.increment(ROMType::U14, v), 8 => self.increment(ROMType::U8, v), 5 => self.increment(ROMType::U5, v), _ => panic!("Unsupported bit range"), @@ -171,6 +172,10 @@ impl LkMultiplicity { self.logic_u8::(a, b) } + pub fn lookup_pow2(&mut self, v: u64) { + self.logic_u8::(2, v) + } + /// Fetch instruction at pc pub fn fetch(&mut self, pc: u32) { self.increment(ROMType::Instruction, pc as u64); diff --git a/examples-builder/Cargo.toml b/examples-builder/Cargo.toml index a7f691478..6862e806b 100644 --- a/examples-builder/Cargo.toml +++ b/examples-builder/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "ceno-examples" -version.workspace = true edition.workspace = true license.workspace = true +name = "ceno-examples" +version.workspace = true diff --git a/examples/Cargo.lock b/examples/Cargo.lock index cab8ac8ef..359355930 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -7,14 +7,13 @@ name = "ceno_rt" version = "0.1.0" dependencies = [ "riscv", - "riscv-rt", ] [[package]] name = "critical-section" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64009896348fc5af4222e9cf7d7d82a95a256c634ebcf61c53e4ea461422242" +checksum = 
"790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" [[package]] name = "embedded-hal" @@ -29,11 +28,17 @@ dependencies = [ "ceno_rt", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" dependencies = [ "unicode-ident", ] @@ -49,40 +54,39 @@ dependencies = [ [[package]] name = "riscv" -version = "0.11.1" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f5c1b8bf41ea746266cdee443d1d1e9125c86ce1447e1a2615abd34330d33a9" +checksum = "5ea8ff73d3720bdd0a97925f0bf79ad2744b6da8ff36be3840c48ac81191d7a7" dependencies = [ "critical-section", "embedded-hal", + "paste", + "riscv-macros", + "riscv-pac", ] [[package]] -name = "riscv-rt" -version = "0.12.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0d35e32cf1383183e8885d8a9aa4402a087fd094dc34c2cb6df6687d0229dfe" -dependencies = [ - "riscv", - "riscv-rt-macros", -] - -[[package]] -name = "riscv-rt-macros" -version = "0.2.1" +name = "riscv-macros" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d100d466dbb76681ef6a9386f3da9abc570d57394e86da0ba5af8c4408486d" +checksum = "f265be5d634272320a7de94cea15c22a3bfdd4eb42eb43edc528415f066a1f25" dependencies = [ "proc-macro2", "quote", "syn", ] +[[package]] +name = "riscv-pac" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8188909339ccc0c68cfb5a04648313f09621e8b87dc03095454f1a11f6c5d436" + [[package]] name = "syn" -version = "1.0.109" +version = "2.0.82" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 1d23821ab..08082d0bd 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,8 +1,8 @@ [package] +edition = "2021" name = "examples" resolver = "2" version = "0.1.0" -edition = "2021" [dependencies] ceno_rt = { path = "../ceno_rt" } diff --git a/examples/examples/ceno_rt_alloc.rs b/examples/examples/ceno_rt_alloc.rs index e44a7f4ca..3e169d9b3 100644 --- a/examples/examples/ceno_rt_alloc.rs +++ b/examples/examples/ceno_rt_alloc.rs @@ -9,7 +9,7 @@ use alloc::{vec, vec::Vec}; static mut OUTPUT: u32 = 0; -#[no_mangle] +ceno_rt::entry!(main); fn main() { // Test writing to a global variable. unsafe { diff --git a/examples/examples/ceno_rt_io.rs b/examples/examples/ceno_rt_io.rs index 3ecaf6a7e..8d2afa6aa 100644 --- a/examples/examples/ceno_rt_io.rs +++ b/examples/examples/ceno_rt_io.rs @@ -5,7 +5,7 @@ extern crate ceno_rt; use ceno_rt::println; use core::fmt::Write; -#[no_mangle] +ceno_rt::entry!(main); fn main() { println!("📜📜📜 Hello, World!"); println!("🌏🌍🌎"); diff --git a/examples/examples/ceno_rt_mem.rs b/examples/examples/ceno_rt_mem.rs index ab31af075..2cc484f80 100644 --- a/examples/examples/ceno_rt_mem.rs +++ b/examples/examples/ceno_rt_mem.rs @@ -7,7 +7,7 @@ use core::ptr::{read_volatile, write_volatile}; extern crate ceno_rt; const OUTPUT_ADDRESS: u32 = 0x8000_0000; -#[no_mangle] +ceno_rt::entry!(main); #[inline(never)] fn main() { test_data_section(); diff --git a/examples/examples/ceno_rt_mini.rs b/examples/examples/ceno_rt_mini.rs index 97c7280b4..70ec64153 100644 --- a/examples/examples/ceno_rt_mini.rs +++ b/examples/examples/ceno_rt_mini.rs @@ -3,5 +3,5 @@ extern crate ceno_rt; -#[no_mangle] +ceno_rt::entry!(main); fn main() {} 
diff --git a/examples/examples/ceno_rt_panic.rs b/examples/examples/ceno_rt_panic.rs index 0e5d3b699..dc8c2fa17 100644 --- a/examples/examples/ceno_rt_panic.rs +++ b/examples/examples/ceno_rt_panic.rs @@ -3,7 +3,7 @@ extern crate ceno_rt; -#[no_mangle] +ceno_rt::entry!(main); fn main() { panic!("This is a panic message!"); } diff --git a/ff_ext/Cargo.toml b/ff_ext/Cargo.toml index f3f15ea3d..5c6a0d937 100644 --- a/ff_ext/Cargo.toml +++ b/ff_ext/Cargo.toml @@ -1,13 +1,11 @@ [package] -name = "ff_ext" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "ff_ext" +version.workspace = true [dependencies] -serde.workspace = true -goldilocks.workspace = true ff.workspace = true +goldilocks.workspace = true poseidon.workspace = true +serde.workspace = true diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index beb487a69..4d87c38c0 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -21,10 +21,11 @@ pub trait ExtensionField: + for<'a> AddAssign<&'a Self::BaseField> + for<'a> SubAssign<&'a Self::BaseField> + for<'a> MulAssign<&'a Self::BaseField> + + Ord { const DEGREE: usize; - type BaseField: SmallField + FromUniformBytes<64> + Poseidon; + type BaseField: SmallField + FromUniformBytes<64> + Poseidon + Ord; fn from_bases(bases: &[Self::BaseField]) -> Self; diff --git a/gkr-graph/Cargo.toml b/gkr-graph/Cargo.toml index b3346c7a6..a8dfbf8d7 100644 --- a/gkr-graph/Cargo.toml +++ b/gkr-graph/Cargo.toml @@ -1,19 +1,17 @@ [package] -name = "gkr-graph" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "gkr-graph" +version.workspace = true [dependencies] -ff_ext = { path = "../ff_ext" } ark-std.workspace = true ff.workspace = true +ff_ext = { path = "../ff_ext" } gkr = { path = "../gkr", features = ["parallel"] } 
goldilocks.workspace = true -itertools = "0.12.1" -multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } -simple-frontend = { version = "0.1.0", path = "../simple-frontend" } -sumcheck = { version = "0.1.0", path = "../sumcheck" } -transcript = { version = "0.1.0", path = "../transcript" } +itertools.workspace = true +multilinear_extensions = { version = "0", path = "../multilinear_extensions" } +simple-frontend = { version = "0", path = "../simple-frontend" } +sumcheck = { version = "0", path = "../sumcheck" } +transcript = { version = "0", path = "../transcript" } diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 629b6891a..fab3ec6b8 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -1,44 +1,41 @@ [package] -name = "gkr" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "gkr" +version.workspace = true [dependencies] -ff.workspace = true ark-std.workspace = true +ff.workspace = true goldilocks.workspace = true -simple-frontend = { path = "../simple-frontend" } -transcript = { path = "../transcript" } -sumcheck = { path = "../sumcheck" } -multilinear_extensions = { path = "../multilinear_extensions" } +crossbeam-channel.workspace = true ff_ext = { path = "../ff_ext" } -serde.workspace = true -serde_json = "1.0.128" +itertools.workspace = true +multilinear_extensions = { path = "../multilinear_extensions" } rayon.workspace = true -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -tracing-flame = "0.2.0" -tracing = "0.1.40" -crossbeam-channel = "0.5.13" -itertools = "0.12.1" -const_env = "0.1.2" +serde.workspace = true +serde_json.workspace = true +simple-frontend = { path = "../simple-frontend" } +sumcheck = { path = "../sumcheck" } +tracing.workspace = true +tracing-flame.workspace = true +tracing-subscriber.workspace = true +transcript = { path = "../transcript" } 
[dev-dependencies] -tiny-keccak = { version = "2.0.2", features = ["keccak"] } -pprof = { version = "0.13", features = ["flamegraph"]} -criterion = { version = "0.5", features = ["html_reports"] } -cfg-if = "1.0.0" +cfg-if.workspace = true +criterion.workspace = true +pprof.workspace = true +tiny-keccak = { version = "2.0", features = ["keccak"] } [features] -default = [ ] +default = [] flamegraph = ["pprof/flamegraph", "pprof/criterion"] -parallel = [ ] non_pow2_rayon_thread = ["sumcheck/non_pow2_rayon_thread"] +parallel = [] unsafe = [] [[bench]] -name = "keccak256" harness = false +name = "keccak256" diff --git a/gkr/Makefile.toml b/gkr/Makefile.toml index f2084a373..f75ac8205 100644 --- a/gkr/Makefile.toml +++ b/gkr/Makefile.toml @@ -4,14 +4,14 @@ CORE = { script = ["nproc"] } RAYON_NUM_THREADS = "${CORE}" [tasks.gkr_bench] -command = "cargo" args = ["bench", "--bench", "keccak256", "--features", "flamegraph", "--features", "parallel", "--package", "gkr"] +command = "cargo" [tasks.gkr_example] -command = "cargo" args = ["run", "--package", "gkr", "--release", "--features", "parallel", "--example", "keccak256"] +command = "cargo" [tasks.gkr_example_flamegraph] -env = { "RUST_LOG" = "debug", "RAYON_NUM_THREADS" = "1"} -command = "cargo" args = ["run", "--package", "gkr", "--release", "--features", "parallel", "--example", "keccak256"] +command = "cargo" +env = { "RUST_LOG" = "debug", "RAYON_NUM_THREADS" = "1" } diff --git a/gkr/benches/keccak256.rs b/gkr/benches/keccak256.rs index fce8ffb31..d48920dd7 100644 --- a/gkr/benches/keccak256.rs +++ b/gkr/benches/keccak256.rs @@ -3,11 +3,11 @@ use std::time::Duration; -use const_env::from_env; use criterion::*; use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak256}; use goldilocks::GoldilocksExt2; +use multilinear_extensions::util::max_usable_threads; cfg_if::cfg_if! { if #[cfg(feature = "flamegraph")] { @@ -28,8 +28,6 @@ cfg_if::cfg_if! 
{ criterion_main!(keccak256); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; fn bench_keccak256(c: &mut Criterion) { println!( @@ -37,26 +35,7 @@ fn bench_keccak256(c: &mut Criterion) { keccak256_circuit::().layers.len() ); - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let circuit = keccak256_circuit::(); diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index eaa4d3742..26ee123ba 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -1,74 +1,65 @@ [package] -name = "mpcs" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "mpcs" +version.workspace = true [dependencies] -ctr = { git = "https://github.com/RustCrypto/block-modes", rev = "78fefba" } +aes = "0.8" +ark-std.workspace = true +bitvec = "1.0" +ctr = "0.9" +ff.workspace = true ff_ext = { path = "../ff_ext" } -serde.workspace = true +# TODO: move to version 1, once our dependencies are updated +generic-array = { version = "0.14", features = ["serde"] } goldilocks.workspace = true -ff.workspace = true -ark-std = { version = "0.4" } -aes = "0.8" -itertools = "0.12.1" -num-bigint = "0.4.6" -num-integer = "0.1.46" -rand = "0.8" -bitvec = "1.0.1" -halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.3", package = "halo2curves", features = [ - "derive_serde", 
-] } -blake2b_simd = "1.0.2" -sha3 = "0.10.8" -rayon = { version = "1.10.0", optional = true } -blake2 = { version = "0.10.6" } +itertools.workspace = true +multilinear_extensions = { path = "../multilinear_extensions" } +num-bigint = "0.4" +num-integer = "0.1" +plonky2.workspace = true poseidon.workspace = true +rand.workspace = true rand_chacha.workspace = true -generic-array = { version = "0.14.7", features = ["serde"] } -multilinear_extensions = { path = "../multilinear_extensions" } +rayon = { workspace = true, optional = true } +serde.workspace = true transcript = { path = "../transcript" } -plonky2 = "0.2.2" [dev-dependencies] -simple-frontend = { path = "../simple-frontend" } -gkr = { path = "../gkr" } -criterion = { version = "0.5", features = ["html_reports"] } +criterion.workspace = true [features] -default = ["parallel"] # Add "sanity-check" to debug -parallel = ["dep:rayon"] benchmark = ["parallel"] +default = ["parallel"] # Add "sanity-check" to debug +parallel = ["dep:rayon"] +print-trace = ["ark-std/print-trace"] sanity-check = [] -print-trace = [ "ark-std/print-trace" ] [[bench]] -name = "commit_open_verify_rs" harness = false +name = "commit_open_verify_rs" [[bench]] -name = "commit_open_verify_basecode" harness = false +name = "commit_open_verify_basecode" [[bench]] -name = "basecode" harness = false +name = "basecode" [[bench]] -name = "rscode" harness = false +name = "rscode" [[bench]] -name = "interpolate" harness = false +name = "interpolate" [[bench]] -name = "fft" harness = false +name = "fft" [[bench]] -name = "utils" harness = false +name = "utils" diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 6d10b5bea..44b91ba76 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -10,7 +10,7 @@ use crate::{ }; use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use ark_std::{end_timer, start_timer}; -use ff::{BatchInverter, Field, PrimeField}; 
+use ff::{BatchInvert, Field, PrimeField}; use ff_ext::ExtensionField; use generic_array::GenericArray; use multilinear_extensions::mle::FieldType; @@ -354,8 +354,7 @@ pub fn get_table_aes( .collect(); // Then invert all the elements. Now weights = { -1/2x } - let mut scratch_space = vec![E::BaseField::ZERO; weights.len()]; - BatchInverter::invert_with_external_scratch(&mut weights, &mut scratch_space); + BatchInvert::batch_invert(&mut weights); // Zip x and -1/2x together. The result is the list { (x, -1/2x) } // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which diff --git a/mpcs/src/util/plonky2_util/mod.rs b/mpcs/src/util/plonky2_util/mod.rs index 909d83333..8ebe0e7d0 100644 --- a/mpcs/src/util/plonky2_util/mod.rs +++ b/mpcs/src/util/plonky2_util/mod.rs @@ -1,10 +1,10 @@ -#![allow(clippy::new_without_default)] -#![allow(clippy::too_many_arguments)] -#![allow(clippy::type_complexity)] +//! This module is adapted from https://github.com/0xPolygonZero/plonky2 + #![allow(clippy::needless_range_loop)] extern crate alloc; +use alloc::vec::Vec; use core::{ hint::unreachable_unchecked, mem::size_of, @@ -13,16 +13,13 @@ use core::{ use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; -mod transpose_util; use transpose_util::transpose_in_place_square; -pub fn bits_u64(n: u64) -> usize { - (64 - n.leading_zeros()) as usize -} +mod transpose_util; /// Computes `ceil(log_2(n))`. #[must_use] -pub fn log2_ceil(n: usize) -> usize { +pub const fn log2_ceil(n: usize) -> usize { (usize::BITS - n.saturating_sub(1).leading_zeros()) as usize } @@ -71,7 +68,8 @@ pub fn reverse_index_bits(arr: &[T]) -> Vec { // } // where reverse_bits(i, n_power) computes the n_power-bit reverse. The complications are there // to guide the compiler to generate optimal assembly. 
-pub fn reverse_bits(n: usize, num_bits: usize) -> usize { + +pub(crate) const fn reverse_bits(n: usize, num_bits: usize) -> usize { // NB: The only reason we need overflowing_shr() here as opposed // to plain '>>' is to accommodate the case n == num_bits == 0, // which would become `0 >> 64`. Rust thinks that any shift of 64 @@ -80,6 +78,7 @@ pub fn reverse_bits(n: usize, num_bits: usize) -> usize { .overflowing_shr(usize::BITS - num_bits as u32) .0 } + fn reverse_index_bits_small(arr: &[T], n_power: usize) -> Vec { let n = arr.len(); let mut result = Vec::with_capacity(n); @@ -112,14 +111,17 @@ fn reverse_index_bits_large(arr: &[T], n_power: usize) -> Vec { } /// Bit-reverse the order of elements in `arr`. -/// SAFETY: ensure that `arr.len() == 1 << lg_n`. +/// SAFETY: ensure that `arr.len() == 1 << lb_n`. #[cfg(not(target_arch = "aarch64"))] -unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lg_n: usize) { - if lg_n <= 6 { - // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lg_n-bit reverses. - let dst_shr_amt = 6 - lg_n; +unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lb_n: usize) { + if lb_n <= 6 { + // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them lb_n-bit reverses. + let dst_shr_amt = 6 - lb_n as u32; for src in 0..arr.len() { - let dst = (BIT_REVERSE_6BIT[src] as usize) >> dst_shr_amt; + // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so + // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the + // correct result. + let dst = (BIT_REVERSE_6BIT[src] as usize).wrapping_shr(dst_shr_amt); if src < dst { swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst)); } @@ -128,11 +130,14 @@ unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lg_n: usize) { // LLVM does not know that it does not need to reverse src at each iteration (which is // expensive on x86). 
We take advantage of the fact that the low bits of dst change rarely and the high // bits of dst are dependent only on the low bits of src. - let dst_lo_shr_amt = 64 - (lg_n - 6); - let dst_hi_shl_amt = lg_n - 6; + let dst_lo_shr_amt = usize::BITS - (lb_n - 6) as u32; + let dst_hi_shl_amt = lb_n - 6; for src_chunk in 0..(arr.len() >> 6) { let src_hi = src_chunk << 6; - let dst_lo = src_chunk.reverse_bits() >> dst_lo_shr_amt; + // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so + // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the + // correct result. + let dst_lo = src_chunk.reverse_bits().wrapping_shr(dst_lo_shr_amt); for src_lo in 0..(1 << 6) { let dst_hi = (BIT_REVERSE_6BIT[src_lo] as usize) << dst_hi_shl_amt; let src = src_hi + src_lo; @@ -146,15 +151,15 @@ unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lg_n: usize) { } /// Bit-reverse the order of elements in `arr`. -/// SAFETY: ensure that `arr.len() == 1 << lg_n`. +/// SAFETY: ensure that `arr.len() == 1 << lb_n`. #[cfg(target_arch = "aarch64")] -unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lg_n: usize) { +unsafe fn reverse_index_bits_in_place_small(arr: &mut [T], lb_n: usize) { // Aarch64 can reverse bits in one instruction, so the trivial version works best. for src in 0..arr.len() { // `wrapping_shr` handles the case when `arr.len() == 1`. In that case `src == 0`, so // `src.reverse_bits() == 0`. `usize::wrapping_shr` by 64 is a no-op, but it gives the // correct result. 
- let dst = src.reverse_bits().wrapping_shr(usize::BITS - lg_n as u32); + let dst = src.reverse_bits().wrapping_shr(usize::BITS - lb_n as u32); if src < dst { swap(arr.get_unchecked_mut(src), arr.get_unchecked_mut(dst)); } @@ -197,38 +202,38 @@ const BIG_T_SIZE: usize = 1 << 14; const SMALL_ARR_SIZE: usize = 1 << 16; pub fn reverse_index_bits_in_place(arr: &mut [T]) { let n = arr.len(); - let lg_n = log2_strict(n); + let lb_n = log2_strict(n); // If the whole array fits in fast cache, then the trivial algorithm is cache friendly. Also, if // `T` is really big, then the trivial algorithm is cache-friendly, no matter the size of the // array. - if size_of::() << lg_n <= SMALL_ARR_SIZE || size_of::() >= BIG_T_SIZE { + if size_of::() << lb_n <= SMALL_ARR_SIZE || size_of::() >= BIG_T_SIZE { unsafe { - reverse_index_bits_in_place_small(arr, lg_n); + reverse_index_bits_in_place_small(arr, lb_n); } } else { debug_assert!(n >= 4); // By our choice of `BIG_T_SIZE` and `SMALL_ARR_SIZE`. // Algorithm: // - // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lg_n` is + // Treat `arr` as a `sqrt(n)` by `sqrt(n)` row-major matrix. (Assume for now that `lb_n` is // even, i.e., `n` is a square number.) To perform bit-order reversal we: // 1. Bit-reverse the order of the rows. (They are contiguous in memory, so this is // basically a series of large `memcpy`s.) // 2. Transpose the matrix. // 3. Bit-reverse the order of the rows. // This is equivalent to, for every index `0 <= i < n`: - // 1. bit-reversing `i[lg_n / 2..lg_n]`, - // 2. swapping `i[0..lg_n / 2]` and `i[lg_n / 2..lg_n]`, - // 3. bit-reversing `i[lg_n / 2..lg_n]`. + // 1. bit-reversing `i[lb_n / 2..lb_n]`, + // 2. swapping `i[0..lb_n / 2]` and `i[lb_n / 2..lb_n]`, + // 3. bit-reversing `i[lb_n / 2..lb_n]`. // - // If `lg_n` is odd, i.e., `n` is not a square number, then the above procedure requires - // slight modification. 
At steps 1 and 3 we bit-reverse bits `ceil(lg_n / 2)..lg_n`, of the - // index (shuffling `floor(lg_n / 2)` chunks of length `ceil(lg_n / 2)`). At step 2, we + // If `lb_n` is odd, i.e., `n` is not a square number, then the above procedure requires + // slight modification. At steps 1 and 3 we bit-reverse bits `ceil(lb_n / 2)..lb_n`, of the + // index (shuffling `floor(lb_n / 2)` chunks of length `ceil(lb_n / 2)`). At step 2, we // perform _two_ transposes. We treat `arr` as two matrices, one where the middle bit of the // index is `0` and another, where the middle bit is `1`; we transpose each individually. - let lb_num_chunks = lg_n >> 1; - let lb_chunk_size = lg_n - lb_num_chunks; + let lb_num_chunks = lb_n >> 1; + let lb_chunk_size = lb_n - lb_num_chunks; unsafe { reverse_index_bits_in_place_chunks(arr, lb_num_chunks, lb_chunk_size); transpose_in_place_square(arr, lb_chunk_size, lb_num_chunks, 0); diff --git a/mpcs/src/util/plonky2_util/transpose_util.rs b/mpcs/src/util/plonky2_util/transpose_util.rs index 6e19a3635..4cc4cc9e5 100644 --- a/mpcs/src/util/plonky2_util/transpose_util.rs +++ b/mpcs/src/util/plonky2_util/transpose_util.rs @@ -1,3 +1,5 @@ +//! This module is adapted from https://github.com/0xPolygonZero/plonky2 + use core::ptr::swap; const LB_BLOCK_SIZE: usize = 3; @@ -27,7 +29,7 @@ unsafe fn transpose_in_place_square_small( } /// Transpose square matrices and swap -/// The matrices are of of size `1 << lb_size` by `1 << lb_size`. They occupy +/// The matrices are of size `1 << lb_size` by `1 << lb_size`. They occupy /// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]` /// for `0 <= i, j < 1 << lb_size. The transposition swaps `M0[i, j]` and `M1[j, i]`. /// @@ -53,7 +55,7 @@ unsafe fn transpose_swap_square_small( } /// Transpose square matrices and swap -/// The matrices are of of size `1 << lb_size` by `1 << lb_size`. They occupy +/// The matrices are of size `1 << lb_size` by `1 << lb_size`. 
They occupy /// `M0[i, j] == arr[(i + x << lb_stride) + j + y]`, `M1[i, j] == arr[i + x + (j + y << lb_stride)]` /// for `0 <= i, j < 1 << lb_size. The transposition swaps `M0[i, j]` and `M1[j, i]`. /// diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index 63db80a27..ad364def1 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -1,27 +1,22 @@ [package] -name = "multilinear_extensions" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "multilinear_extensions" +version.workspace = true [dependencies] -tracing = "0.1.40" -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -tracing-flame = "0.2.0" -ff_ext = { path = "../ff_ext" } -itertools = "0.12.1" ark-std.workspace = true ff.workspace = true +ff_ext = { path = "../ff_ext" } goldilocks.workspace = true +itertools.workspace = true rayon.workspace = true serde.workspace = true +tracing.workspace = true [dev-dependencies] -criterion = { version = "0.5", features = ["html_reports"] } +env_logger = "0.11" log = "0.4" -env_logger = "0.7" [features] -parallel = [ ] +parallel = [] diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 8ce82ae02..8a182e645 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -44,32 +44,22 @@ pub trait MultilinearExtension: Send + Sync { fn name(&self) -> &'static str; fn get_ext_field_vec(&self) -> &[E] { - self.get_ext_field_vec_optn() - .unwrap_or_else(|| unreachable!()) - } - - fn get_base_field_vec(&self) -> &[E::BaseField] { - self.get_base_field_vec_optn() - .unwrap_or_else(|| unreachable!()) - } - - fn get_ext_field_vec_optn(&self) -> Option<&[E]> { match &self.evaluations() { FieldType::Ext(evaluations) => { let (start, offset) = self.evaluations_range().unwrap_or((0, evaluations.len())); - 
Some(&evaluations[start..][..offset]) + &evaluations[start..][..offset] } - _ => None, + _ => panic!("evaluation not in extension field"), } } - fn get_base_field_vec_optn(&self) -> Option<&[E::BaseField]> { + fn get_base_field_vec(&self) -> &[E::BaseField] { match &self.evaluations() { FieldType::Base(evaluations) => { let (start, offset) = self.evaluations_range().unwrap_or((0, evaluations.len())); - Some(&evaluations[start..][..offset]) + &evaluations[start..][..offset] } - _ => None, + _ => panic!("evaluation not in base field"), } } } @@ -158,6 +148,14 @@ impl FieldType { FieldType::Unreachable => true, } } + + pub fn variant_name(&self) -> &'static str { + match self { + FieldType::Base(_) => "Base", + FieldType::Ext(_) => "Ext", + FieldType::Unreachable => "Unreachable", + } + } } /// Stores a multilinear polynomial in dense evaluation form. @@ -1032,101 +1030,131 @@ macro_rules! op_mle { } #[macro_export] -macro_rules! op_mle_3 { +macro_rules! op_mle3_range { + ($x:ident, $a:ident, $b:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{ + let $x = if let Some((start, offset)) = $x.evaluations_range() { + &$x_vec[start..][..offset] + } else { + &$x_vec[..] + }; + let $a = if let Some((start, offset)) = $a.evaluations_range() { + &$a_vec[start..][..offset] + } else { + &$a_vec[..] + }; + let $b = if let Some((start, offset)) = $b.evaluations_range() { + &$b_vec[start..][..offset] + } else { + &$b_vec[..] + }; + let $bb_out = $op; + $op_bb_out + }}; +} + +/// deal with x * a + b +#[macro_export] +macro_rules! 
op_mle_xa_b { + (|$x:ident, $a:ident, $b:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { + match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { + ( + $crate::mle::FieldType::Base(x_vec), + $crate::mle::FieldType::Base(a_vec), + $crate::mle::FieldType::Base(b_vec), + ) => { + op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } + ( + $crate::mle::FieldType::Base(x_vec), + $crate::mle::FieldType::Ext(a_vec), + $crate::mle::FieldType::Base(b_vec), + ) => { + op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } + ( + $crate::mle::FieldType::Base(x_vec), + $crate::mle::FieldType::Ext(a_vec), + $crate::mle::FieldType::Ext(b_vec), + ) => { + op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } + (x, a, b) => unreachable!( + "unmatched pattern {:?} {:?} {:?}", + x.variant_name(), + a.variant_name(), + b.variant_name() + ), + } + }; + (|$x:ident, $a:ident, $b:ident| $op:expr) => { + op_mle_xa_b!(|$x, $a, $b| $op, |out| out) + }; +} + +/// deal with f1 * f2 * f3 +/// applying commutative rule for f1, f2, f3 to canonical form: Ext field comes first followed by Base Field +#[macro_export] +macro_rules! op_mle_product_3 { (|$f1:ident, $f2:ident, $f3:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { match (&$f1.evaluations(), &$f2.evaluations(), &$f3.evaluations()) { + // capture non-canonical form ( - $crate::mle::FieldType::Base(f1), - $crate::mle::FieldType::Base(f2), - $crate::mle::FieldType::Base(f3), + $crate::mle::FieldType::Ext(_), + $crate::mle::FieldType::Base(_), + $crate::mle::FieldType::Ext(_), ) => { - let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { - &f1[start..][..offset] - } else { - &f1[..] - }; - let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { - &f2[start..][..offset] - } else { - &f2[..] - }; - let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { - &f3[start..][..offset] - } else { - &f3[..]
- }; - let $bb_out = $op; - $op_bb_out + op_mle_product_3!(@internal |$f1, $f3, $f2| { + let ($f2, $f3) = ($f3, $f2); + $op + }, |$bb_out| $op_bb_out) } + // ...add more non-canonical form + // default will go canonical form + _ => op_mle_product_3!(@internal |$f1, $f2, $f3| $op, |$bb_out| $op_bb_out), + } + }; + (@internal |$f1:ident, $f2:ident, $f3:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { + match (&$f1.evaluations(), &$f2.evaluations(), &$f3.evaluations()) { ( - $crate::mle::FieldType::Ext(f1), - $crate::mle::FieldType::Base(f2), - $crate::mle::FieldType::Base(f3), + $crate::mle::FieldType::Base(f1_vec), + $crate::mle::FieldType::Base(f2_vec), + $crate::mle::FieldType::Base(f3_vec), ) => { - let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { - &f1[start..][..offset] - } else { - &f1[..] - }; - let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { - &f2[start..][..offset] - } else { - &f2[..] - }; - let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { - &f3[start..][..offset] - } else { - &f3[..] - }; - $op + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) } ( - $crate::mle::FieldType::Ext(f1), - $crate::mle::FieldType::Ext(f2), - $crate::mle::FieldType::Ext(f3), + $crate::mle::FieldType::Ext(f1_vec), + $crate::mle::FieldType::Base(f2_vec), + $crate::mle::FieldType::Base(f3_vec), ) => { - let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { - &f1[start..][..offset] - } else { - &f1[..] - }; - let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { - &f2[start..][..offset] - } else { - &f2[..] - }; - let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { - &f3[start..][..offset] - } else { - &f3[..] 
- }; - $op + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) } ( - $crate::mle::FieldType::Ext(f1), - $crate::mle::FieldType::Ext(f2), - $crate::mle::FieldType::Base(f3), + $crate::mle::FieldType::Ext(f1_vec), + $crate::mle::FieldType::Ext(f2_vec), + $crate::mle::FieldType::Ext(f3_vec), ) => { - let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { - &f1[start..][..offset] - } else { - &f1[..] - }; - let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { - &f2[start..][..offset] - } else { - &f2[..] - }; - let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { - &f3[start..][..offset] - } else { - &f3[..] - }; - $op + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) } - _ => unreachable!(), + ( + $crate::mle::FieldType::Ext(f1_vec), + $crate::mle::FieldType::Ext(f2_vec), + $crate::mle::FieldType::Base(f3_vec), + ) => { + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) + } + // ... add more canonical case if missing + (a, b, c) => unreachable!( + "unmatched pattern {:?} {:?} {:?}", + a.variant_name(), + b.variant_name(), + c.variant_name() + ), } }; + (|$f1:ident, $f2:ident, $f3:ident| $op:expr) => { + op_mle_product_3!(|$f1, $f2, $f3| $op, |out| out) + }; } /// macro support op(a, b) and tackles type matching internally.
diff --git a/multilinear_extensions/src/util.rs b/multilinear_extensions/src/util.rs index 28e4f8284..a0a8e56a2 100644 --- a/multilinear_extensions/src/util.rs +++ b/multilinear_extensions/src/util.rs @@ -30,3 +30,21 @@ pub fn create_uninit_vec(len: usize) -> Vec> { pub fn largest_even_below(n: usize) -> usize { if n % 2 == 0 { n } else { n.saturating_sub(1) } } + +fn prev_power_of_two(n: usize) -> usize { + (n + 1).next_power_of_two() / 2 +} + +/// Largest power of two that fits the available rayon threads +pub fn max_usable_threads() -> usize { + if cfg!(test) { + 1 + } else { + let n = rayon::current_num_threads(); + let threads = prev_power_of_two(n); + if n != threads { + tracing::warn!("thread size {n} is not power of 2, using {threads} threads instead."); + } + threads + } +} diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index a5468ad90..45e4e9fb7 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -2,7 +2,7 @@ use std::{cmp::max, collections::HashMap, marker::PhantomData, mem::MaybeUninit, use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - util::{bit_decompose, create_uninit_vec}, + util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; use ark_std::{end_timer, iterable::Iterable, rand::Rng, start_timer}; use ff::{Field, PrimeField}; @@ -452,8 +452,7 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { // .... 
// 1 1 1 1 -> r0 * r1 * r2 * r3 // we will need 2^num_var evaluations - let nthreads = - std::env::var("RAYON_NUM_THREADS").map_or(8, |s| s.parse::().unwrap_or(8)); + let nthreads = max_usable_threads(); let nbits = nthreads.trailing_zeros() as usize; assert_eq!(1 << nbits, nthreads); diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml index 9560dffba..489f4efc1 100644 --- a/poseidon/Cargo.toml +++ b/poseidon/Cargo.toml @@ -1,21 +1,21 @@ [package] -name = "poseidon" -version.workspace = true edition.workspace = true license.workspace = true +name = "poseidon" +version.workspace = true [dependencies] criterion.workspace = true ff.workspace = true goldilocks.workspace = true serde.workspace = true -unroll = "0.1.5" +unroll = "0.1" [dev-dependencies] -plonky2 = "0.2.2" -rand = "0.8.5" ark-std.workspace = true +plonky2.workspace = true +rand.workspace = true [[bench]] -name = "hashing" harness = false +name = "hashing" diff --git a/rustfmt.toml b/rustfmt.toml index fd6aeb116..1ff1f984d 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,8 +1,8 @@ -edition = "2024" -style_edition = "2024" -wrap_comments = false comment_width = 300 +edition = "2024" imports_granularity = "Crate" max_width = 100 newline_style = "Unix" normalize_comments = true +style_edition = "2024" +wrap_comments = false diff --git a/simple-frontend/Cargo.toml b/simple-frontend/Cargo.toml index 60af0127c..1e642b3e6 100644 --- a/simple-frontend/Cargo.toml +++ b/simple-frontend/Cargo.toml @@ -1,17 +1,13 @@ [package] -name = "simple-frontend" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "simple-frontend" +version.workspace = true [dependencies] -ff_ext = { path = "../ff_ext" } -ark-std.workspace = true ff.workspace = true +ff_ext = { path = "../ff_ext" } goldilocks.workspace = true -itertools = "0.12.1" -rayon.workspace = true +itertools.workspace = true 
serde.workspace = true diff --git a/singer-utils/Cargo.toml b/singer-utils/Cargo.toml index 11e391b02..e679b06ef 100644 --- a/singer-utils/Cargo.toml +++ b/singer-utils/Cargo.toml @@ -1,23 +1,21 @@ [package] -name = "singer-utils" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "singer-utils" +version.workspace = true [dependencies] ark-std.workspace = true -ff_ext = { path = "../ff_ext" } ff.workspace = true +ff_ext = { path = "../ff_ext" } goldilocks.workspace = true gkr = { path = "../gkr", features = ["parallel"] } -itertools = "0.12.1" -simple-frontend = { version = "0.1.0", path = "../simple-frontend" } -gkr-graph = { version = "0.1.0", path = "../gkr-graph" } -sumcheck = { version = "0.1.0", path = "../sumcheck" } -strum = "0.26.3" -strum_macros = "0.26.4" -transcript = { version = "0.1.0", path = "../transcript" } -multilinear_extensions = { path = "../multilinear_extensions", features = [ "parallel"] } +gkr-graph = { version = "0", path = "../gkr-graph" } +itertools.workspace = true +multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] } +simple-frontend = { version = "0", path = "../simple-frontend" } +strum.workspace = true +strum_macros.workspace = true +sumcheck = { version = "0", path = "../sumcheck" } +transcript = { version = "0", path = "../transcript" } diff --git a/singer/Cargo.toml b/singer/Cargo.toml index 15cb11636..ac71ab4b9 100644 --- a/singer/Cargo.toml +++ b/singer/Cargo.toml @@ -1,10 +1,8 @@ [package] -name = "singer" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "singer" +version.workspace = true [dependencies] ark-std.workspace = true @@ -13,36 +11,34 @@ goldilocks.workspace = true rayon.workspace = true serde.workspace = true +ff_ext 
= { path = "../ff_ext" } gkr = { path = "../gkr", features = ["parallel"] } +gkr-graph = { version = "0", path = "../gkr-graph" } +multilinear_extensions = { version = "0", path = "../multilinear_extensions" } +simple-frontend = { version = "0", path = "../simple-frontend" } +singer-utils = { version = "0", path = "../singer-utils" } transcript = { path = "../transcript" } -mpcs = { path = "../mpcs" } -gkr-graph = { version = "0.1.0", path = "../gkr-graph" } -simple-frontend = { version = "0.1.0", path = "../simple-frontend" } -singer-utils = { version = "0.1.0", path = "../singer-utils" } -multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } -ff_ext = { path = "../ff_ext" } -itertools = "0.12.1" -strum = "0.25.0" -strum_macros = "0.25.3" -paste = "1.0.15" -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -tracing-flame = "0.2.0" -tracing = "0.1.40" +itertools.workspace = true +paste.workspace = true +strum.workspace = true +strum_macros.workspace = true +tracing-flame.workspace = true +tracing-subscriber.workspace = true [dev-dependencies] -pprof = { version = "0.13", features = ["flamegraph"]} -criterion = { version = "0.5", features = ["html_reports"] } -cfg-if = "1.0.0" -const_env = "0.1.2" +cfg-if.workspace = true +criterion.workspace = true +pprof.workspace = true +tracing.workspace = true [features] -witness-count = [] -test-dbg = [] dbg-add-opcode = [] -non_pow2_rayon_thread = [] flamegraph = ["pprof/flamegraph", "pprof/criterion"] +non_pow2_rayon_thread = [] +test-dbg = [] +witness-count = [] [[bench]] -name = "add" harness = false +name = "add" diff --git a/singer/benches/add.rs b/singer/benches/add.rs index 70fee6f28..5984a19b2 100644 --- a/singer/benches/add.rs +++ b/singer/benches/add.rs @@ -4,7 +4,6 @@ use std::time::{Duration, Instant}; use ark_std::test_rng; -use const_env::from_env; use criterion::*; use ff_ext::{ExtensionField, ff::Field}; @@ -30,9 +29,8 @@ cfg_if::cfg_if! 
{ criterion_main!(op_add); const NUM_SAMPLES: usize = 10; -#[from_env] -const RAYON_NUM_THREADS: usize = 8; +use multilinear_extensions::util::max_usable_threads; use singer::{ CircuitWiresIn, SingerGraphBuilder, SingerParams, instructions::{Instruction, InstructionGraph, SingerCircuitBuilder, add::AddInstruction}, @@ -42,26 +40,7 @@ use singer_utils::structs::ChipChallenges; use transcript::Transcript; fn bench_add(c: &mut Criterion) { - let max_thread_id = { - if !RAYON_NUM_THREADS.is_power_of_two() { - #[cfg(not(feature = "non_pow2_rayon_thread"))] - { - panic!( - "add --features non_pow2_rayon_thread to enable unsafe feature which support non pow of 2 rayon thread pool" - ); - } - - #[cfg(feature = "non_pow2_rayon_thread")] - { - use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2}; - let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS); - create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true); - max_thread_id - } - } else { - RAYON_NUM_THREADS - } - }; + let max_thread_id = max_usable_threads(); let chip_challenges = ChipChallenges::default(); let circuit_builder = SingerCircuitBuilder::::new(chip_challenges).expect("circuit builder failed"); diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index f46b2d27a..540495c0a 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -1,34 +1,29 @@ [package] -name = "sumcheck" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "sumcheck" +version.workspace = true [dependencies] -ff_ext = { path = "../ff_ext" } -ark-ff = "0.4.2" -tracing = "0.1.40" -const_env = "0.1.2" -itertools = "0.12.1" ark-std.workspace = true ff.workspace = true +ff_ext = { path = "../ff_ext" } goldilocks.workspace = true +itertools.workspace = true rayon.workspace = true serde.workspace = true +tracing.workspace = true -multilinear_extensions = { path = 
"../multilinear_extensions", features = [ "parallel"] } +crossbeam-channel.workspace = true +multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] } transcript = { path = "../transcript" } -crossbeam-channel = "0.5.13" [dev-dependencies] -pprof = { version = "0.13", features = ["flamegraph"]} -criterion = { version = "0.5", features = ["html_reports"] } +criterion.workspace = true [[bench]] -name = "devirgo_sumcheck" harness = false +name = "devirgo_sumcheck" [features] -non_pow2_rayon_thread = [ ] +non_pow2_rayon_thread = [] diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index a051e7615..fd33b9d09 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -1,20 +1,20 @@ #![allow(clippy::manual_memcpy)] #![allow(clippy::needless_range_loop)] -use std::sync::Arc; +use std::array; use ark_std::test_rng; -use const_env::from_env; use criterion::*; -use ff_ext::{ExtensionField, ff::Field}; +use ff_ext::ExtensionField; use itertools::Itertools; -use sumcheck::{structs::IOPProverState, util::ceil_log2}; +use sumcheck::{structs::IOPProverStateV2 as IOPProverState, util::ceil_log2}; use goldilocks::GoldilocksExt2; use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, + mle::DenseMultilinearExtension, + op_mle, + util::max_usable_threads, + virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; use transcript::Transcript; @@ -22,64 +22,88 @@ criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,); criterion_main!(benches); const NUM_SAMPLES: usize = 10; +const NUM_DEGREE: usize = 3; +const NV: [usize; 2] = [25, 26]; + +/// transpose 2d vector without clone +pub fn transpose(v: Vec>) -> Vec> { + assert!(!v.is_empty()); + let len = v[0].len(); + let mut iters: Vec<_> = v.into_iter().map(|n| 
n.into_iter()).collect(); + (0..len) + .map(|_| { + iters + .iter_mut() + .map(|n| n.next().unwrap()) + .collect::>() + }) + .collect() +} -fn prepare_input( - max_thread_id: usize, +fn prepare_input<'a, E: ExtensionField>( nv: usize, -) -> (E, VirtualPolynomial, Vec>) { +) -> (E, VirtualPolynomial<'a, E>, Vec>) { let mut rng = test_rng(); + let max_thread_id = max_usable_threads(); let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() + let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| { + let mle: ArcMultilinearExtension<'a, E> = + DenseMultilinearExtension::::random(nv, &mut rng).into(); + mle + }); + + let mut virtual_poly_v1 = VirtualPolynomial::new(nv); + virtual_poly_v1.add_mle_list(fs.to_vec(), E::ONE); + + // devirgo version + let virtual_poly_v2: Vec>> = transpose( + fs.iter() + .map(|f| match &f.evaluations() { + multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations + .chunks((1 << nv) >> size_log2) + .map(|chunk| { + let mle: ArcMultilinearExtension<'a, E> = + DenseMultilinearExtension::::from_evaluations_vec( + nv - size_log2, + chunk.to_vec(), + ) + .into(); + mle + }) + .collect_vec(), + _ => unreachable!(), }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations 
- .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() + .collect(), + ); + let virtual_poly_v2: Vec> = virtual_poly_v2 + .into_iter() + .map(|fs| { + let mut virtual_polynomial = VirtualPolynomial::new(fs[0].num_vars()); + virtual_polynomial.add_mle_list(fs, E::ONE); + virtual_polynomial + }) + .collect(); + + let asserted_sum = fs + .iter() + .fold(vec![E::ONE; 1 << nv], |mut acc, f| { + op_mle!(f, |f| { + (0..f.len()).zip(acc.iter_mut()).for_each(|(i, acc)| { + *acc *= f[i]; + }); + acc }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); + }) + .iter() + .sum::(); - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - (asserted_sum, virtual_poly_1, virtual_poly_f1) + (asserted_sum, virtual_poly_v1, virtual_poly_v2) } -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in [13, 14, 15, 16].into_iter() { + for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv)); group.sample_size(NUM_SAMPLES); @@ -92,7 +116,7 @@ fn sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -123,7 +147,8 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in [13, 14, 15, 16].into_iter() { + let threads = max_usable_threads(); + for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); 
group.sample_size(NUM_SAMPLES); @@ -136,7 +161,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { || { let prover_transcript = Transcript::::new(b"test"); let (asserted_sum, virtual_poly, virtual_poly_splitted) = - { prepare_input(RAYON_NUM_THREADS, nv) }; + { prepare_input(nv) }; ( prover_transcript, asserted_sum, @@ -151,7 +176,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { virtual_poly_splitted, )| { let (_sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, + threads, virtual_poly_splitted, &mut prover_transcript, ); diff --git a/sumcheck/examples/devirgo_sumcheck.rs b/sumcheck/examples/devirgo_sumcheck.rs deleted file mode 100644 index 29ff81368..000000000 --- a/sumcheck/examples/devirgo_sumcheck.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::Arc; - -use ark_std::test_rng; -use const_env::from_env; -use ff_ext::{ExtensionField, ff::Field}; -use goldilocks::GoldilocksExt2; -use itertools::Itertools; -use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, -}; -use sumcheck::{ - structs::{IOPProverState, IOPVerifierState}, - util::ceil_log2, -}; -use transcript::Transcript; - -type E = GoldilocksExt2; - -fn prepare_input( - max_thread_id: usize, -) -> (E, VirtualPolynomial, Vec>) { - let nv = 10; - let mut rng = test_rng(); - let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - 
size_log2, chunk.to_vec()) - .into() - }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() - }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); - - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - (asserted_sum, virtual_poly_1, virtual_poly_f1) -} - -#[from_env] -const RAYON_NUM_THREADS: usize = 8; - -fn main() { - let mut prover_transcript_v1 = Transcript::::new(b"test"); - let mut prover_transcript_v2 = Transcript::::new(b"test"); - - let (asserted_sum, virtual_poly, virtual_poly_splitted) = prepare_input(RAYON_NUM_THREADS); - let (sumcheck_proof_v2, _) = IOPProverState::::prove_batch_polys( - RAYON_NUM_THREADS, - virtual_poly_splitted.clone(), - &mut prover_transcript_v2, - ); - println!("v2 finish"); - - let mut transcript = Transcript::new(b"test"); - let poly_info = virtual_poly.aux_info.clone(); - let subclaim = IOPVerifierState::::verify( - asserted_sum, - &sumcheck_proof_v2, - &poly_info, - &mut transcript, - ); - assert!( - virtual_poly.evaluate( - subclaim - .point - .iter() - .map(|c| c.elements) - .collect::>() - .as_ref() - ) == subclaim.expected_evaluation, - "wrong subclaim" - ); - - #[allow(deprecated)] - let (sumcheck_proof_v1, _) = - IOPProverState::::prove_parallel(virtual_poly.clone(), &mut prover_transcript_v1); - - println!("v1 finish"); - assert!(sumcheck_proof_v2 == sumcheck_proof_v1); -} diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index 8d1c4c952..f2ede8680 100644 --- 
a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -7,7 +7,7 @@ use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, MultilinearExtension}, - op_mle, op_mle_3, + op_mle, op_mle_product_3, op_mle3_range, util::largest_even_below, virtual_poly_v2::VirtualPolynomialV2, }; @@ -43,6 +43,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) -> (IOPProof, IOPProverStateV2<'a, E>) { assert!(!polys.is_empty()); assert_eq!(polys.len(), max_thread_id); + assert!(max_thread_id.is_power_of_two()); let log2_max_thread_id = ceil_log2(max_thread_id); // do not support SIZE not power of 2 assert!( @@ -507,7 +508,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[1]], &self.poly.flattened_ml_extensions[products[2]], ); - op_mle_3!( + op_mle_product_3!( |f1, f2, f3| { let res = (0..largest_even_below(f1.len())) .step_by(2) @@ -863,7 +864,7 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { &self.poly.flattened_ml_extensions[products[1]], &self.poly.flattened_ml_extensions[products[2]], ); - op_mle_3!( + op_mle_product_3!( |f1, f2, f3| { let res = (0..largest_even_below(f1.len())) .step_by(2) diff --git a/taplo.toml b/taplo.toml new file mode 100644 index 000000000..d857ad4e0 --- /dev/null +++ b/taplo.toml @@ -0,0 +1,6 @@ +# Configuration doc: https://taplo.tamasfe.dev/configuration/formatter-options.html +[formatting] +align_comments = false +array_auto_collapse = false +array_auto_expand = false +reorder_keys = true diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index 4f1e0dbf0..4769afe00 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -1,18 +1,13 @@ [package] -name = "transcript" -version.workspace = true edition.workspace = true license.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +name = "transcript" +version.workspace = true [dependencies] 
-ff_ext = { path = "../ff_ext" } -crossbeam-channel = "0.5.13" -ark-std.workspace = true +crossbeam-channel.workspace = true ff.workspace = true +ff_ext = { path = "../ff_ext" } goldilocks.workspace = true -halo2curves.workspace = true -rayon.workspace = true +poseidon.workspace = true serde.workspace = true -poseidon.workspace = true \ No newline at end of file