From ef3fe9a3b5bc966c1d12b1833b4e69cebef925f7 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Mon, 16 Oct 2023 08:07:26 +0000 Subject: [PATCH 01/18] Init commit for rust backend Signed-off-by: GitHub --- .gitignore | 1 + backend/rust/Cargo.lock | 1234 ++++++++++++++++++++++++++++++++++++++ backend/rust/Cargo.toml | 20 + backend/rust/Makefile | 19 + backend/rust/README.md | 39 ++ backend/rust/build.rs | 9 + backend/rust/src/main.rs | 88 +++ 7 files changed, 1410 insertions(+) create mode 100644 backend/rust/Cargo.lock create mode 100644 backend/rust/Cargo.toml create mode 100644 backend/rust/Makefile create mode 100644 backend/rust/README.md create mode 100644 backend/rust/build.rs create mode 100644 backend/rust/src/main.rs diff --git a/.gitignore b/.gitignore index 8ffe29584056..a031473cf2a9 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,4 @@ release/ backend-assets/ prepare /ggml-metal.metal +/backend/rust/target \ No newline at end of file diff --git a/backend/rust/Cargo.lock b/backend/rust/Cargo.lock new file mode 100644 index 000000000000..c3dd6ee99577 --- /dev/null +++ b/backend/rust/Cargo.lock @@ -0,0 +1,1234 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" + +[[package]] +name = "async-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5" +dependencies = [ + "async-stream-impl 0.2.1", + "futures-core", +] + +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl 0.3.5", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "base64" +version = "0.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures-channel" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + +[[package]] +name = "futures-task" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" + +[[package]] +name = "futures-util" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", +] + +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "gimli" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" + +[[package]] +name = "h2" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap 1.9.3", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + +[[package]] +name = "home" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" +dependencies = [ + 
"windows-sys", +] + +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "0.14.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.4.9", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" +dependencies = [ + "equivalent", + "hashbrown 0.14.1", +] + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "libc" +version = "0.2.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" + +[[package]] +name = "linux-raw-sys" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +dependencies = [ + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys", +] + +[[package]] +name = "multimap" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "percent-encoding" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + +[[package]] +name = "petgraph" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +dependencies = [ + "fixedbitset", + "indexmap 2.0.2", +] + +[[package]] +name = "pin-project" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "prettyplease" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" +dependencies = [ + "proc-macro2", + "syn 2.0.38", +] + +[[package]] +name = "proc-macro2" +version = 
"1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4fdd22f3b9c31b53c060df4a0613a1c7f062d4115a2b984dd15b1858f7e340d" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bdf592881d821b83d471f8af290226c8d51402259e9bb5be7f9f8bdebbb11ac" +dependencies = [ + "bytes", + "heck", + "itertools", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 2.0.38", + "tempfile", + "which", +] + +[[package]] +name = "prost-derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "265baba7fabd416cf5078179f7d2cbeca4ce7a9041111900675ea7c4cb8a4c32" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "prost-types" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e081b29f63d83a4bc75cfc9f3fe424f9156cf92d8a4f0c9407cce9a1b67327cf" +dependencies = [ + "prost", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.10", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "regex" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "rust" +version = "0.1.0" +dependencies = [ + "async-stream 0.2.1", + "prost", + "rand 0.7.3", + "serde", + "serde_json", + "tokio", + "tokio-stream", + "tonic", + "tonic-build", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "rustix" +version = "0.38.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" +dependencies = [ + "bitflags 2.4.1", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "serde" +version = "1.0.189" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.189" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "serde_json" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "socket2" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "tempfile" +version = "3.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", + "windows-sys", +] + +[[package]] +name = "tokio" +version = "1.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "pin-project-lite", + "socket2 0.5.4", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-macros" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "tonic" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" +dependencies = [ + "async-stream 0.3.5", + "async-trait", + "axum", + "base64", + "bytes", + "h2", + "http", + "http-body", + "hyper", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand 0.8.5", + 
"slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ 
+ "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml new file mode 100644 index 000000000000..43e0d112e252 --- /dev/null +++ b/backend/rust/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "rust" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tonic = "0.10" +prost = "0.12" +tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } +tokio-stream = "0.1" + +async-stream = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +rand = "0.7" + +[build-dependencies] +tonic-build = "0.10" \ No newline at end of file diff --git a/backend/rust/Makefile b/backend/rust/Makefile new file mode 100644 index 000000000000..9a6a79eacbd0 --- /dev/null +++ b/backend/rust/Makefile @@ -0,0 +1,19 @@ +.PHONY: check +check: + @echo "Checking..." + @cargo check + +.POHNY: build +build: + @echo "Building..." + @cargo build + +.PHONY: run +run: + @echo "Running..." + @cargo run --release + +.PHONY: doc +doc: + @echo "Documenting..." + @cargo doc --no-deps --document-private-items --open diff --git a/backend/rust/README.md b/backend/rust/README.md new file mode 100644 index 000000000000..8489d0d6fe22 --- /dev/null +++ b/backend/rust/README.md @@ -0,0 +1,39 @@ +## Here is a backend written in Rust for the LocalAI project + +Here are some rules for the Rust backend: +* Same proto file with the LocalAI's other backends, we should keep the same interface of the backend. +* `async` should be as the default way to write code. +* Streaming response should be supported. 
+* Only server-side gRPC services are supported by the current backend. +* The backend should also have metrics for monitoring. + + +### Environment information + +* cargo 1.73.0 (9c4383fb5 2023-08-26) +* rustup 1.26.0 (5af9b9484 2023-04-05) +* rustc 1.73.0 (cc66ad468 2023-10-03) + +## Set up the development environment + +### Protocol Buffers compiler + +Ubuntu or Debian + +``` +sudo apt update && sudo apt upgrade -y +sudo apt install -y protobuf-compiler libprotobuf-dev +``` + +macOS +``` +brew install protobuf +``` + +### Generating the server-side code + +> The Rust backend uses the same proto file as the other backends, so we keep the same backend interface. The generated backend.rs ends up in the /target folder and does not need to be managed by git. + +``` +make build +``` \ No newline at end of file diff --git a/backend/rust/build.rs b/backend/rust/build.rs new file mode 100644 index 000000000000..249a2f4b303c --- /dev/null +++ b/backend/rust/build.rs @@ -0,0 +1,9 @@ +fn main() { + tonic_build::configure() + .build_server(true) + .build_client(false) + .compile( + &["../../pkg/grpc/proto/backend.proto"], + &["../../pkg/grpc/proto"]) + .expect("Failed to compile proto file"); +} \ No newline at end of file diff --git a/backend/rust/src/main.rs b/backend/rust/src/main.rs new file mode 100644 index 000000000000..14c3c78804c1 --- /dev/null +++ b/backend/rust/src/main.rs @@ -0,0 +1,88 @@ +use backend::backend_server::{Backend, BackendServer}; +use backend::{HealthMessage, PredictOptions, Reply, ModelOptions, EmbeddingResult, GenerateImageRequest, TranscriptRequest, TranscriptResult, TtsRequest, TokenizationResponse}; +use tonic::{Request, Response, Status}; +use tokio_stream::wrappers::ReceiverStream; + +use tonic::transport::Server; + +pub mod backend{ + tonic::include_proto!("backend"); +} + + +#[derive(Debug)] +struct BackendService; + +#[tonic::async_trait] +impl Backend for BackendService{ + + // The `Result` message generated from the proto file conflicts with std::result::Result, + // so we refer to it by its fully qualified name (backend::Result) where needed. + + async fn health(&self, request: Request<HealthMessage>) -> Result<Response<Reply>, Status> { + + // TODO: Maybe we can move this to a logger + println!("Got a request: {:?}", request); + + let reply = backend::Reply { + message: format!("OK").into(), + }; + + Ok(Response::new(reply)) + + } + + // implement the predict function + async fn predict(&self, request: Request<PredictOptions>) -> Result<Response<Reply>, Status> { + unimplemented!("Not implemented yet") + } + + // implement the load_model function + async fn load_model(&self, request: Request<ModelOptions>) -> Result<Response<backend::Result>, Status> { + unimplemented!("Not implemented yet") + } + + type PredictStreamStream = ReceiverStream<Result<Reply, Status>>; + + async fn predict_stream(&self, request: Request<PredictOptions>) -> Result<Response<Self::PredictStreamStream>, Status> { + unimplemented!("Not implemented yet") + } + + async fn embedding(&self, request: Request<PredictOptions>) -> Result<Response<EmbeddingResult>, Status> { + unimplemented!("Not implemented yet") + } + + async fn generate_image(&self, request: Request<GenerateImageRequest>) -> Result<Response<backend::Result>, Status> { + unimplemented!("Not implemented yet") + } + + async fn audio_transcription(&self, request: Request<TranscriptRequest>) -> Result<Response<TranscriptResult>, Status> { + unimplemented!("Not implemented yet") + } + + async fn tts(&self, request: Request<TtsRequest>) -> Result<Response<backend::Result>, Status> { + unimplemented!("Not implemented yet") + } + + async fn tokenize_string(&self, request: Request<PredictOptions>) -> Result<Response<TokenizationResponse>, Status> { + unimplemented!("Not implemented yet") + } + + async fn status(&self, request: Request<HealthMessage>) -> Result<Response<backend::StatusResponse>, Status> { + unimplemented!("Not implemented yet") + } + +} + +#[tokio::main] +async fn 
main() -> Result<(), Box> { + let addr = "[::1]:50052".parse().unwrap(); + + let backend = BackendService {}; + + let svc = BackendServer::new(backend); + + Server::builder().add_service(svc).serve(addr).await?; + + Ok(()) +} From 029a71fe03e4fec2b10c955ce8b6fcfde4eec525 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Wed, 18 Oct 2023 10:47:18 +1100 Subject: [PATCH 02/18] Update backend/rust/Makefile Co-authored-by: Luca Barbato Signed-off-by: Aisuko --- backend/rust/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/rust/Makefile b/backend/rust/Makefile index 9a6a79eacbd0..c5871f45ade3 100644 --- a/backend/rust/Makefile +++ b/backend/rust/Makefile @@ -3,7 +3,7 @@ check: @echo "Checking..." @cargo check -.POHNY: build +.PHONY: build build: @echo "Building..." @cargo build From 5c67aa67aa567e8d11ee125484a9465042dce210 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Wed, 18 Oct 2023 00:50:03 +0000 Subject: [PATCH 03/18] Add tracing Signed-off-by: GitHub --- backend/rust/Cargo.lock | 31 +- backend/rust/Cargo.toml | 25 +- backend/rust/bunker/Cargo.toml | 21 + backend/rust/{ => bunker}/build.rs | 5 +- backend/rust/bunker/src/backend.rs | 966 +++++++++++++++++++++++++++++ backend/rust/bunker/src/lib.rs | 5 + backend/rust/bunker/src/service.rs | 21 + backend/rust/burn/Cargo.toml | 22 + backend/rust/{ => burn}/Makefile | 0 backend/rust/burn/src/main.rs | 40 ++ backend/rust/src/main.rs | 88 --- 11 files changed, 1099 insertions(+), 125 deletions(-) create mode 100644 backend/rust/bunker/Cargo.toml rename backend/rust/{ => bunker}/build.rs (57%) create mode 100644 backend/rust/bunker/src/backend.rs create mode 100644 backend/rust/bunker/src/lib.rs create mode 100644 backend/rust/bunker/src/service.rs create mode 100644 backend/rust/burn/Cargo.toml rename backend/rust/{ => burn}/Makefile (100%) create mode 100644 backend/rust/burn/src/main.rs delete mode 100644 backend/rust/src/main.rs diff --git a/backend/rust/Cargo.lock b/backend/rust/Cargo.lock index c3dd6ee99577..f2e9eec81300 100644 --- a/backend/rust/Cargo.lock +++ b/backend/rust/Cargo.lock @@ -170,6 +170,22 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +[[package]] +name = "bunker" +version = "0.1.0" +dependencies = [ + "async-stream 0.2.1", + "async-trait", + "prost", + "rand 0.7.3", + "serde", + "serde_json", + "tokio", + "tokio-stream", + "tonic", + "tonic-build", +] + [[package]] name = "bytes" version = "1.5.0" @@ -787,21 +803,6 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" -[[package]] -name = "rust" -version = "0.1.0" -dependencies = [ - "async-stream 0.2.1", - "prost", - "rand 0.7.3", - "serde", - "serde_json", - "tokio", - "tokio-stream", - "tonic", - "tonic-build", -] - [[package]] name = "rustc-demangle" version = "0.1.23" diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 43e0d112e252..4fa1c2286f9b 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -1,20 +1,5 @@ -[package] -name = "rust" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -tonic = "0.10" -prost = "0.12" -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } -tokio-stream = "0.1" - -async-stream = "0.2" -serde = { version = "1.0", features = 
["derive"] } -serde_json = "1.0" -rand = "0.7" - -[build-dependencies] -tonic-build = "0.10" \ No newline at end of file +[workspace] +resolver = "2" +members = [ + "bunker", +] \ No newline at end of file diff --git a/backend/rust/bunker/Cargo.toml b/backend/rust/bunker/Cargo.toml new file mode 100644 index 000000000000..086317e916f4 --- /dev/null +++ b/backend/rust/bunker/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "bunker" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tonic = "0.10" +prost = "0.12" +tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } +tokio-stream = "0.1" + +async-stream = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +rand = "0.7" +async-trait = "0.1.74" + +[build-dependencies] +tonic-build = "0.10" diff --git a/backend/rust/build.rs b/backend/rust/bunker/build.rs similarity index 57% rename from backend/rust/build.rs rename to backend/rust/bunker/build.rs index 249a2f4b303c..d9d332a04e4d 100644 --- a/backend/rust/build.rs +++ b/backend/rust/bunker/build.rs @@ -1,9 +1,10 @@ fn main() { tonic_build::configure() + .out_dir("src") .build_server(true) .build_client(false) .compile( - &["../../pkg/grpc/proto/backend.proto"], - &["../../pkg/grpc/proto"]) + &["../../../pkg/grpc/proto/backend.proto"], + &["../../../pkg/grpc/proto"]) .expect("Failed to compile proto file"); } \ No newline at end of file diff --git a/backend/rust/bunker/src/backend.rs b/backend/rust/bunker/src/backend.rs new file mode 100644 index 000000000000..8e06344905d5 --- /dev/null +++ b/backend/rust/bunker/src/backend.rs @@ -0,0 +1,966 @@ +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct HealthMessage {} +/// The request message containing the user's name. 
+#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PredictOptions { + #[prost(string, tag = "1")] + pub prompt: ::prost::alloc::string::String, + #[prost(int32, tag = "2")] + pub seed: i32, + #[prost(int32, tag = "3")] + pub threads: i32, + #[prost(int32, tag = "4")] + pub tokens: i32, + #[prost(int32, tag = "5")] + pub top_k: i32, + #[prost(int32, tag = "6")] + pub repeat: i32, + #[prost(int32, tag = "7")] + pub batch: i32, + #[prost(int32, tag = "8")] + pub n_keep: i32, + #[prost(float, tag = "9")] + pub temperature: f32, + #[prost(float, tag = "10")] + pub penalty: f32, + #[prost(bool, tag = "11")] + pub f16kv: bool, + #[prost(bool, tag = "12")] + pub debug_mode: bool, + #[prost(string, repeated, tag = "13")] + pub stop_prompts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(bool, tag = "14")] + pub ignore_eos: bool, + #[prost(float, tag = "15")] + pub tail_free_sampling_z: f32, + #[prost(float, tag = "16")] + pub typical_p: f32, + #[prost(float, tag = "17")] + pub frequency_penalty: f32, + #[prost(float, tag = "18")] + pub presence_penalty: f32, + #[prost(int32, tag = "19")] + pub mirostat: i32, + #[prost(float, tag = "20")] + pub mirostat_eta: f32, + #[prost(float, tag = "21")] + pub mirostat_tau: f32, + #[prost(bool, tag = "22")] + pub penalize_nl: bool, + #[prost(string, tag = "23")] + pub logit_bias: ::prost::alloc::string::String, + #[prost(bool, tag = "25")] + pub m_lock: bool, + #[prost(bool, tag = "26")] + pub m_map: bool, + #[prost(bool, tag = "27")] + pub prompt_cache_all: bool, + #[prost(bool, tag = "28")] + pub prompt_cache_ro: bool, + #[prost(string, tag = "29")] + pub grammar: ::prost::alloc::string::String, + #[prost(string, tag = "30")] + pub main_gpu: ::prost::alloc::string::String, + #[prost(string, tag = "31")] + pub tensor_split: ::prost::alloc::string::String, + #[prost(float, tag = "32")] + pub top_p: f32, + #[prost(string, tag = "33")] + pub prompt_cache_path: ::prost::alloc::string::String, + #[prost(bool, tag = "34")] + pub debug: bool, + #[prost(int32, repeated, tag = "35")] + pub embedding_tokens: ::prost::alloc::vec::Vec, + #[prost(string, tag = "36")] + pub embeddings: ::prost::alloc::string::String, + #[prost(float, tag = "37")] + pub rope_freq_base: f32, + #[prost(float, tag = "38")] + pub rope_freq_scale: f32, + #[prost(float, tag = "39")] + pub negative_prompt_scale: f32, + #[prost(string, tag = "40")] + pub negative_prompt: ::prost::alloc::string::String, + #[prost(int32, tag = "41")] + pub n_draft: i32, +} +/// The response message containing the result +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Reply { + #[prost(bytes = "vec", tag = "1")] + pub message: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ModelOptions { + #[prost(string, tag = "1")] + pub model: ::prost::alloc::string::String, + #[prost(int32, tag = "2")] + pub context_size: i32, + #[prost(int32, tag = "3")] + pub seed: i32, + #[prost(int32, tag = "4")] + pub n_batch: i32, + #[prost(bool, tag = "5")] + pub f16_memory: bool, + #[prost(bool, tag = "6")] + pub m_lock: bool, + #[prost(bool, tag = "7")] + pub m_map: bool, + #[prost(bool, tag = "8")] + pub vocab_only: bool, + #[prost(bool, tag = "9")] + pub low_vram: bool, + #[prost(bool, tag = "10")] + pub embeddings: bool, + #[prost(bool, tag = "11")] + pub numa: bool, + #[prost(int32, tag = "12")] + pub ngpu_layers: 
i32, + #[prost(string, tag = "13")] + pub main_gpu: ::prost::alloc::string::String, + #[prost(string, tag = "14")] + pub tensor_split: ::prost::alloc::string::String, + #[prost(int32, tag = "15")] + pub threads: i32, + #[prost(string, tag = "16")] + pub library_search_path: ::prost::alloc::string::String, + #[prost(float, tag = "17")] + pub rope_freq_base: f32, + #[prost(float, tag = "18")] + pub rope_freq_scale: f32, + #[prost(float, tag = "19")] + pub rms_norm_eps: f32, + #[prost(int32, tag = "20")] + pub ngqa: i32, + #[prost(string, tag = "21")] + pub model_file: ::prost::alloc::string::String, + /// AutoGPTQ + #[prost(string, tag = "22")] + pub device: ::prost::alloc::string::String, + #[prost(bool, tag = "23")] + pub use_triton: bool, + #[prost(string, tag = "24")] + pub model_base_name: ::prost::alloc::string::String, + #[prost(bool, tag = "25")] + pub use_fast_tokenizer: bool, + /// Diffusers + #[prost(string, tag = "26")] + pub pipeline_type: ::prost::alloc::string::String, + #[prost(string, tag = "27")] + pub scheduler_type: ::prost::alloc::string::String, + #[prost(bool, tag = "28")] + pub cuda: bool, + #[prost(float, tag = "29")] + pub cfg_scale: f32, + #[prost(bool, tag = "30")] + pub img2img: bool, + #[prost(string, tag = "31")] + pub clip_model: ::prost::alloc::string::String, + #[prost(string, tag = "32")] + pub clip_subfolder: ::prost::alloc::string::String, + #[prost(int32, tag = "33")] + pub clip_skip: i32, + /// RWKV + #[prost(string, tag = "34")] + pub tokenizer: ::prost::alloc::string::String, + /// LLM (llama.cpp) + #[prost(string, tag = "35")] + pub lora_base: ::prost::alloc::string::String, + #[prost(string, tag = "36")] + pub lora_adapter: ::prost::alloc::string::String, + #[prost(bool, tag = "37")] + pub no_mul_mat_q: bool, + #[prost(string, tag = "39")] + pub draft_model: ::prost::alloc::string::String, + #[prost(string, tag = "38")] + pub audio_path: ::prost::alloc::string::String, + /// vllm + #[prost(string, tag = "40")] + pub quantization: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Result { + #[prost(string, tag = "1")] + pub message: ::prost::alloc::string::String, + #[prost(bool, tag = "2")] + pub success: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct EmbeddingResult { + #[prost(float, repeated, tag = "1")] + pub embeddings: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TranscriptRequest { + #[prost(string, tag = "2")] + pub dst: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub language: ::prost::alloc::string::String, + #[prost(uint32, tag = "4")] + pub threads: u32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TranscriptResult { + #[prost(message, repeated, tag = "1")] + pub segments: ::prost::alloc::vec::Vec, + #[prost(string, tag = "2")] + pub text: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TranscriptSegment { + #[prost(int32, tag = "1")] + pub id: i32, + #[prost(int64, tag = "2")] + pub start: i64, + #[prost(int64, tag = "3")] + pub end: i64, + #[prost(string, tag = "4")] + pub text: ::prost::alloc::string::String, + #[prost(int32, repeated, tag = "5")] + pub tokens: ::prost::alloc::vec::Vec, +} 
+#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GenerateImageRequest { + #[prost(int32, tag = "1")] + pub height: i32, + #[prost(int32, tag = "2")] + pub width: i32, + #[prost(int32, tag = "3")] + pub mode: i32, + #[prost(int32, tag = "4")] + pub step: i32, + #[prost(int32, tag = "5")] + pub seed: i32, + #[prost(string, tag = "6")] + pub positive_prompt: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub negative_prompt: ::prost::alloc::string::String, + #[prost(string, tag = "8")] + pub dst: ::prost::alloc::string::String, + #[prost(string, tag = "9")] + pub src: ::prost::alloc::string::String, + /// Diffusers + #[prost(string, tag = "10")] + pub enable_parameters: ::prost::alloc::string::String, + #[prost(int32, tag = "11")] + pub clip_skip: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TtsRequest { + #[prost(string, tag = "1")] + pub text: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub model: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub dst: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TokenizationResponse { + #[prost(int32, tag = "1")] + pub length: i32, + #[prost(int32, repeated, tag = "2")] + pub tokens: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MemoryUsageData { + #[prost(uint64, tag = "1")] + pub total: u64, + #[prost(map = "string, uint64", tag = "2")] + pub breakdown: ::std::collections::HashMap<::prost::alloc::string::String, u64>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StatusResponse { + #[prost(enumeration = "status_response::State", tag = "1")] + pub state: i32, + #[prost(message, optional, tag = "2")] + pub memory: ::core::option::Option, +} +/// Nested message and enum types in `StatusResponse`. +pub mod status_response { + #[derive( + Clone, + Copy, + Debug, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + ::prost::Enumeration + )] + #[repr(i32)] + pub enum State { + Uninitialized = 0, + Busy = 1, + Ready = 2, + Error = -1, + } + impl State { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + State::Uninitialized => "UNINITIALIZED", + State::Busy => "BUSY", + State::Ready => "READY", + State::Error => "ERROR", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNINITIALIZED" => Some(Self::Uninitialized), + "BUSY" => Some(Self::Busy), + "READY" => Some(Self::Ready), + "ERROR" => Some(Self::Error), + _ => None, + } + } + } +} +/// Generated server implementations. +pub mod backend_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with BackendServer. 
+ #[async_trait] + pub trait Backend: Send + Sync + 'static { + async fn health( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn predict( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn load_model( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the PredictStream method. + type PredictStreamStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + Send + + 'static; + async fn predict_stream( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + async fn embedding( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn generate_image( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn audio_transcription( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + async fn tts( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + async fn tokenize_string( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + async fn status( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct BackendServer { + inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + struct _Inner(Arc); + impl BackendServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + let inner = _Inner(inner); + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. 
+ /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for BackendServer + where + T: Backend, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + match req.uri().path() { + "/backend.Backend/Health" => { + #[allow(non_camel_case_types)] + struct HealthSvc(pub Arc); + impl tonic::server::UnaryService + for HealthSvc { + type Response = super::Reply; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::health(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = HealthSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/Predict" => { + #[allow(non_camel_case_types)] + struct PredictSvc(pub Arc); + impl tonic::server::UnaryService + for PredictSvc { + type Response = super::Reply; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::predict(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = PredictSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/LoadModel" => { + #[allow(non_camel_case_types)] + struct LoadModelSvc(pub Arc); + impl tonic::server::UnaryService + for LoadModelSvc { + type Response = super::Result; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::load_model(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = 
self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = LoadModelSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/PredictStream" => { + #[allow(non_camel_case_types)] + struct PredictStreamSvc(pub Arc); + impl< + T: Backend, + > tonic::server::ServerStreamingService + for PredictStreamSvc { + type Response = super::Reply; + type ResponseStream = T::PredictStreamStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::predict_stream(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = PredictStreamSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/Embedding" => { + #[allow(non_camel_case_types)] + struct EmbeddingSvc(pub Arc); + impl tonic::server::UnaryService + for EmbeddingSvc { + type Response = super::EmbeddingResult; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::embedding(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = EmbeddingSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/GenerateImage" => { + #[allow(non_camel_case_types)] + struct GenerateImageSvc(pub Arc); + impl< + T: Backend, + > tonic::server::UnaryService + for GenerateImageSvc { + type Response = super::Result; + type Future = BoxFuture< + tonic::Response, + 
tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::generate_image(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GenerateImageSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/AudioTranscription" => { + #[allow(non_camel_case_types)] + struct AudioTranscriptionSvc(pub Arc); + impl< + T: Backend, + > tonic::server::UnaryService + for AudioTranscriptionSvc { + type Response = super::TranscriptResult; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::audio_transcription(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = AudioTranscriptionSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/TTS" => { + #[allow(non_camel_case_types)] + struct TTSSvc(pub Arc); + impl tonic::server::UnaryService + for TTSSvc { + type Response = super::Result; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::tts(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = TTSSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/TokenizeString" => { + #[allow(non_camel_case_types)] + struct 
TokenizeStringSvc(pub Arc); + impl tonic::server::UnaryService + for TokenizeStringSvc { + type Response = super::TokenizationResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::tokenize_string(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = TokenizeStringSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/backend.Backend/Status" => { + #[allow(non_camel_case_types)] + struct StatusSvc(pub Arc); + impl tonic::server::UnaryService + for StatusSvc { + type Response = super::StatusResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::status(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = StatusSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } + } + } + } + impl Clone for BackendServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } + } + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl tonic::server::NamedService for BackendServer { + const NAME: &'static str = "backend.Backend"; + } +} diff --git a/backend/rust/bunker/src/lib.rs b/backend/rust/bunker/src/lib.rs new file mode 100644 index 000000000000..a9c9b94b7823 --- /dev/null +++ b/backend/rust/bunker/src/lib.rs @@ -0,0 +1,5 @@ +pub mod pb{ + 
include!("backend.rs"); +} + +pub mod service; \ No newline at end of file diff --git a/backend/rust/bunker/src/service.rs b/backend/rust/bunker/src/service.rs new file mode 100644 index 000000000000..aefc4c1707c6 --- /dev/null +++ b/backend/rust/bunker/src/service.rs @@ -0,0 +1,21 @@ +use crate::pb::{HealthMessage, PredictOptions,Reply,ModelOptions, EmbeddingResult, GenerateImageRequest,TranscriptRequest,TranscriptResult,TtsRequest,TokenizationResponse,StatusResponse}; +use crate::pb::Result as PbResult; +use tonic::{Request, Response, Status}; +use tokio_stream::wrappers::ReceiverStream; +use async_trait::async_trait; + + +#[async_trait] +trait BackendService>>{ + async fn health(&self, request: Request) -> Result,Status>; + async fn predict(&self, request: Request) -> Result,Status>; + async fn load_model(&self, request: Request) -> Result,Status>; + async fn predict_stream(&self, request: Request) -> Result,Status>; // https://github.com/rust-lang/rust/issues/29661 + async fn embedding(&self, request: Request) -> Result,Status>; + async fn generate_image(&self, request: Request) -> Result,Status>; + async fn audio_transcription(&self, request: Request) -> Result,Status>; + async fn text_to_speech(&self, request: Request) -> Result,Status>; + async fn tokenization(&self, request: Request) -> Result,Status>; + async fn status(&self, request: Request) -> Result,Status>; + +} \ No newline at end of file diff --git a/backend/rust/burn/Cargo.toml b/backend/rust/burn/Cargo.toml new file mode 100644 index 000000000000..1ee6828167c0 --- /dev/null +++ b/backend/rust/burn/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "burn" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +bunker = { path = "../bunker" } + +tonic = "0.10" +prost = "0.12" +tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } +tokio-stream = "0.1" + +async-stream = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +rand = "0.7" +tracing = "0.1" +tracing-subscriber = "0.3" diff --git a/backend/rust/Makefile b/backend/rust/burn/Makefile similarity index 100% rename from backend/rust/Makefile rename to backend/rust/burn/Makefile diff --git a/backend/rust/burn/src/main.rs b/backend/rust/burn/src/main.rs new file mode 100644 index 000000000000..4148a0b556b1 --- /dev/null +++ b/backend/rust/burn/src/main.rs @@ -0,0 +1,40 @@ + +use tracing; +use tracing_subscriber; +use bunker::service::BackendService; + +// implement BackendService trait in bunker + +struct BurnBackend; + +#[async_trait] +impl BackendService for BurnBackend{ + unimplemented!(); +} + + +#[tokio::main] +async fn main() -> Result<(), Box> { + unimplemented!(); + + // let subscriber = tracing_subscriber::fmt() + // .compact() + // .with_file(true) + // .with_line_number(true) + // .with_target(true) + // .with_level(true) + // .finish(); + + // tracing::subscriber::set_global_default(subscriber) + // .expect("setting default subscriber failed"); + + // let addr = "[::1]:50052".parse().unwrap(); + + // let backend = BackendService {}; + + // let svc = BackendServer::new(backend); + + // Server::builder().add_service(svc).serve(addr).await?; + + // Ok(()) +} diff --git a/backend/rust/src/main.rs b/backend/rust/src/main.rs deleted file mode 100644 index 14c3c78804c1..000000000000 --- a/backend/rust/src/main.rs +++ /dev/null @@ -1,88 +0,0 @@ -use backend::backend_server::{Backend, BackendServer}; -use 
backend::{HealthMessage, PredictOptions,Reply,ModelOptions, EmbeddingResult, GenerateImageRequest,TranscriptRequest,TranscriptResult,TtsRequest,TokenizationResponse}; -use tonic::{Request, Response, Status}; -use tokio_stream::{wrappers::ReceiverStream}; - -use tonic::transport::Server; - -pub mod backend{ - tonic::include_proto!("backend"); -} - - -#[derive(Debug)] -struct BackendService; - -#[tonic::async_trait] -impl Backend for BackendService{ - - // Result in proto/backend.rs is conflict with std::result::Result - // So we need to use use the fully qualified name of the Result type in the protobuf file - - async fn health(&self, request: Request) -> Result,Status> { - - //TODO: Maybe we can move this to a logger - println!("Got a request: {:?}", request); - - let reply = backend::Reply { - message: format!("OK").into(), - }; - - Ok(Response::new(reply)) - - } - - // implmenet the predict function - async fn predict(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - // implement the model function - async fn load_model(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - type PredictStreamStream = ReceiverStream>; - - async fn predict_stream(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - async fn embedding(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - async fn generate_image(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - async fn audio_transcription(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - async fn tts(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - async fn tokenize_string(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - - async fn status(&self, request: Request) -> Result,Status> { - unimplemented!("Not implemented yet") - } - -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - let addr = "[::1]:50052".parse().unwrap(); - - let backend = BackendService {}; - - let svc = BackendServer::new(backend); - - Server::builder().add_service(svc).serve(addr).await?; - - Ok(()) -} From 1806dd7e85f7aeff64720d132d15f27ff1b350dc Mon Sep 17 00:00:00 2001 From: Aisuko Date: Wed, 18 Oct 2023 18:50:43 +1100 Subject: [PATCH 04/18] Add workspace Signed-off-by: Aisuko --- .gitignore | 2 +- backend/rust/Cargo.lock | 96 ++++++++++++++++++++++++++++++ backend/rust/Cargo.toml | 1 + backend/rust/bunker/src/lib.rs | 6 +- backend/rust/bunker/src/service.rs | 6 +- backend/rust/burn/Cargo.toml | 1 + backend/rust/burn/src/main.rs | 85 ++++++++++++++++++++------ 7 files changed, 172 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index a031473cf2a9..3986e42857f2 100644 --- a/.gitignore +++ b/.gitignore @@ -41,4 +41,4 @@ release/ backend-assets/ prepare /ggml-metal.metal -/backend/rust/target \ No newline at end of file +/backend/rust/target/ diff --git a/backend/rust/Cargo.lock b/backend/rust/Cargo.lock index f2e9eec81300..9ef570f5daa8 100644 --- a/backend/rust/Cargo.lock +++ b/backend/rust/Cargo.lock @@ -186,6 +186,23 @@ dependencies = [ "tonic-build", ] +[[package]] +name = "burn" +version = "0.1.0" +dependencies = [ + "async-stream 0.2.1", + "bunker", + "prost", + "rand 0.7.3", + "serde", + "serde_json", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-subscriber", +] + [[package]] name = "bytes" version = "1.5.0" 
@@ -471,6 +488,12 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.149" @@ -533,6 +556,16 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -558,6 +591,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "percent-encoding" version = "2.3.0" @@ -865,6 +904,15 @@ dependencies = [ "serde", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "slab" version = "0.4.9" @@ -874,6 +922,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "smallvec" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" + [[package]] name = "socket2" version = "0.4.9" @@ -935,6 +989,16 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "thread_local" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "tokio" version = "1.33.0" @@ -1099,6 +1163,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -1113,6 +1203,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = 
"want" version = "0.3.1" diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 4fa1c2286f9b..194f7573a59f 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -2,4 +2,5 @@ resolver = "2" members = [ "bunker", + "burn", ] \ No newline at end of file diff --git a/backend/rust/bunker/src/lib.rs b/backend/rust/bunker/src/lib.rs index a9c9b94b7823..484161001f54 100644 --- a/backend/rust/bunker/src/lib.rs +++ b/backend/rust/bunker/src/lib.rs @@ -1,5 +1,5 @@ +pub mod service; + pub mod pb{ include!("backend.rs"); -} - -pub mod service; \ No newline at end of file +} \ No newline at end of file diff --git a/backend/rust/bunker/src/service.rs b/backend/rust/bunker/src/service.rs index aefc4c1707c6..1cebb11daff6 100644 --- a/backend/rust/bunker/src/service.rs +++ b/backend/rust/bunker/src/service.rs @@ -6,7 +6,7 @@ use async_trait::async_trait; #[async_trait] -trait BackendService>>{ +pub trait BackendService>>{ async fn health(&self, request: Request) -> Result,Status>; async fn predict(&self, request: Request) -> Result,Status>; async fn load_model(&self, request: Request) -> Result,Status>; @@ -14,8 +14,8 @@ trait BackendService>>{ async fn embedding(&self, request: Request) -> Result,Status>; async fn generate_image(&self, request: Request) -> Result,Status>; async fn audio_transcription(&self, request: Request) -> Result,Status>; - async fn text_to_speech(&self, request: Request) -> Result,Status>; - async fn tokenization(&self, request: Request) -> Result,Status>; + async fn tts(&self, request: Request) -> Result,Status>; + async fn tokenize_string(&self, request: Request) -> Result,Status>; async fn status(&self, request: Request) -> Result,Status>; } \ No newline at end of file diff --git a/backend/rust/burn/Cargo.toml b/backend/rust/burn/Cargo.toml index 1ee6828167c0..dfc127d8f670 100644 --- a/backend/rust/burn/Cargo.toml +++ b/backend/rust/burn/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] +# import bunker here bunker = { path = "../bunker" } tonic = "0.10" diff --git a/backend/rust/burn/src/main.rs b/backend/rust/burn/src/main.rs index 4148a0b556b1..4c5015975320 100644 --- a/backend/rust/burn/src/main.rs +++ b/backend/rust/burn/src/main.rs @@ -1,40 +1,89 @@ - +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::Server; +use tonic::{async_trait, Request, Response, Status}; use tracing; use tracing_subscriber; +use bunker::pb::{EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest}; +use bunker::pb::Result as PbResult; use bunker::service::BackendService; + // implement BackendService trait in bunker struct BurnBackend; #[async_trait] -impl BackendService for BurnBackend{ - unimplemented!(); +impl BackendService>> for BurnBackend{ + + async fn health(&self, request: Request) -> Result,Status> { + // return a Result,Status> + let reply = Reply { + message: "OK".into(), + }; + let res=Response::new(reply); + Ok(res) + } + + async fn predict(&self, request: Request) -> Result,Status> { + todo!() + } + + async fn load_model(&self, request: Request) -> Result, Status> { + todo!() + } + + async fn predict_stream(&self, request: Request) -> Result>>, Status> { + todo!() + } + + async fn embedding(&self, request: Request) -> Result, Status> { + todo!() + } + + async fn generate_image(&self, request: Request) -> Result, Status> { + todo!() + } + + async fn audio_transcription(&self, request: Request) -> Result, Status> 
{ + todo!() + } + + async fn tts(&self, request: Request) -> Result, Status> { + todo!() + } + + async fn tokenize_string(&self, request: Request) -> Result, Status> { + todo!() + } + + async fn status(&self, request: Request) -> Result, Status> { + todo!() + } + } #[tokio::main] async fn main() -> Result<(), Box> { - unimplemented!(); - // let subscriber = tracing_subscriber::fmt() - // .compact() - // .with_file(true) - // .with_line_number(true) - // .with_target(true) - // .with_level(true) - // .finish(); + let subscriber = tracing_subscriber::fmt() + .compact() + .with_file(true) + .with_line_number(true) + .with_target(true) + .with_level(true) + .finish(); - // tracing::subscriber::set_global_default(subscriber) - // .expect("setting default subscriber failed"); + tracing::subscriber::set_global_default(subscriber) + .expect("setting default subscriber failed"); - // let addr = "[::1]:50052".parse().unwrap(); + let addr = "[::1]:50052".parse().unwrap(); - // let backend = BackendService {}; + let backend = BackendService {}; - // let svc = BackendServer::new(backend); + let svc = BurnBackend::new(backend); - // Server::builder().add_service(svc).serve(addr).await?; + Server::builder().add_service(svc).serve(addr).await?; - // Ok(()) + Ok(()) } From 61bd269af205c4dc8766bf4a3a64d9372881c551 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Thu, 19 Oct 2023 15:23:15 +1100 Subject: [PATCH 05/18] Replace the generated file to the generated folder Signed-off-by: Aisuko --- backend/rust/Cargo.lock | 18 +++-- backend/rust/Cargo.toml | 1 + backend/rust/Makefile | 30 +++++++ backend/rust/README.md | 6 +- backend/rust/bunker/Cargo.toml | 7 +- backend/rust/bunker/build.rs | 10 --- .../rust/bunker/{src => generated}/backend.rs | 0 backend/rust/bunker/src/lib.rs | 9 ++- backend/rust/bunker/src/service.rs | 56 +++++++++---- backend/rust/burn/Cargo.toml | 16 +--- backend/rust/burn/src/main.rs | 81 +++++++++---------- backend/rust/codegen/Cargo.toml | 9 +++ backend/rust/codegen/build.rs | 11 +++ backend/rust/codegen/src/lib.rs | 1 + 14 files changed, 155 insertions(+), 100 deletions(-) create mode 100644 backend/rust/Makefile delete mode 100644 backend/rust/bunker/build.rs rename backend/rust/bunker/{src => generated}/backend.rs (100%) create mode 100644 backend/rust/codegen/Cargo.toml create mode 100644 backend/rust/codegen/build.rs create mode 100644 backend/rust/codegen/src/lib.rs diff --git a/backend/rust/Cargo.lock b/backend/rust/Cargo.lock index 9ef570f5daa8..f4d4397ca217 100644 --- a/backend/rust/Cargo.lock +++ b/backend/rust/Cargo.lock @@ -183,24 +183,19 @@ dependencies = [ "tokio", "tokio-stream", "tonic", - "tonic-build", + "tracing", + "tracing-subscriber", ] [[package]] name = "burn" version = "0.1.0" dependencies = [ - "async-stream 0.2.1", + "async-trait", "bunker", - "prost", - "rand 0.7.3", - "serde", - "serde_json", "tokio", "tokio-stream", "tonic", - "tracing", - "tracing-subscriber", ] [[package]] @@ -224,6 +219,13 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "codegen" +version = "0.1.0" +dependencies = [ + "tonic-build", +] + [[package]] name = "either" version = "1.9.0" diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 194f7573a59f..15a22f54d8b6 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -3,4 +3,5 @@ resolver = "2" members = [ "bunker", "burn", + "codegen", ] \ No newline at end of file diff --git 
a/backend/rust/Makefile b/backend/rust/Makefile new file mode 100644 index 000000000000..2a18f0facd82 --- /dev/null +++ b/backend/rust/Makefile @@ -0,0 +1,30 @@ +# default are fmt and then check +.DEFAULT_GOAL := all + +.PHONY: all +all: fmt check + +.PHONY: fmt +fmt: + @echo "Formatting code..." + @cargo fmt --all -- --check + +.PHONY: build +build: + @echo "Building..." + @cargo build --release + +.PHONY: test +test: + @echo "Testing..." + @cargo test --all + +.PHONY: check +check: + @echo "Checking..." + @cargo check --all + +.PHONY: clean +clean: + @echo "Cleaning..." + @cargo clean diff --git a/backend/rust/README.md b/backend/rust/README.md index 8489d0d6fe22..de0012b9b92d 100644 --- a/backend/rust/README.md +++ b/backend/rust/README.md @@ -30,10 +30,8 @@ macOS brew install protobuf ``` -### Generating the server side code - -> Rust backend uses the same proto file with the other backends, so we should keep the same interface of the backend. So, the output file of backend.rs is in the /target folder and do not need to be managed by git. +### Cargo fmt all the code ``` -make build +cargo fmt --all --check ``` \ No newline at end of file diff --git a/backend/rust/bunker/Cargo.toml b/backend/rust/bunker/Cargo.toml index 086317e916f4..d9d9d8ff505f 100644 --- a/backend/rust/bunker/Cargo.toml +++ b/backend/rust/bunker/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tonic = "0.10" +tonic = { version = "0.10", features = [] } prost = "0.12" tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } tokio-stream = "0.1" @@ -16,6 +16,5 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" rand = "0.7" async-trait = "0.1.74" - -[build-dependencies] -tonic-build = "0.10" +tracing-subscriber = "0.3.17" +tracing = "0.1.39" diff --git a/backend/rust/bunker/build.rs b/backend/rust/bunker/build.rs deleted file mode 100644 index d9d332a04e4d..000000000000 --- a/backend/rust/bunker/build.rs +++ /dev/null @@ -1,10 +0,0 @@ -fn main() { - tonic_build::configure() - .out_dir("src") - .build_server(true) - .build_client(false) - .compile( - &["../../../pkg/grpc/proto/backend.proto"], - &["../../../pkg/grpc/proto"]) - .expect("Failed to compile proto file"); -} \ No newline at end of file diff --git a/backend/rust/bunker/src/backend.rs b/backend/rust/bunker/generated/backend.rs similarity index 100% rename from backend/rust/bunker/src/backend.rs rename to backend/rust/bunker/generated/backend.rs diff --git a/backend/rust/bunker/src/lib.rs b/backend/rust/bunker/src/lib.rs index 484161001f54..b4c8fb2726c2 100644 --- a/backend/rust/bunker/src/lib.rs +++ b/backend/rust/bunker/src/lib.rs @@ -1,5 +1,6 @@ -pub mod service; +/// Import the code() backend.rs was generated from backend.proto +pub mod pb { + include!("../generated/backend.rs"); +} -pub mod pb{ - include!("backend.rs"); -} \ No newline at end of file +pub mod service; diff --git a/backend/rust/bunker/src/service.rs b/backend/rust/bunker/src/service.rs index 1cebb11daff6..81cf3132cea0 100644 --- a/backend/rust/bunker/src/service.rs +++ b/backend/rust/bunker/src/service.rs @@ -1,21 +1,43 @@ -use crate::pb::{HealthMessage, PredictOptions,Reply,ModelOptions, EmbeddingResult, GenerateImageRequest,TranscriptRequest,TranscriptResult,TtsRequest,TokenizationResponse,StatusResponse}; +//! Contains the service trait for the bunker service. 
+ use crate::pb::Result as PbResult; -use tonic::{Request, Response, Status}; -use tokio_stream::wrappers::ReceiverStream; +use crate::pb::{ + EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, + StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest, +}; use async_trait::async_trait; - +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; #[async_trait] -pub trait BackendService>>{ - async fn health(&self, request: Request) -> Result,Status>; - async fn predict(&self, request: Request) -> Result,Status>; - async fn load_model(&self, request: Request) -> Result,Status>; - async fn predict_stream(&self, request: Request) -> Result,Status>; // https://github.com/rust-lang/rust/issues/29661 - async fn embedding(&self, request: Request) -> Result,Status>; - async fn generate_image(&self, request: Request) -> Result,Status>; - async fn audio_transcription(&self, request: Request) -> Result,Status>; - async fn tts(&self, request: Request) -> Result,Status>; - async fn tokenize_string(&self, request: Request) -> Result,Status>; - async fn status(&self, request: Request) -> Result,Status>; - -} \ No newline at end of file +pub trait BackendService>> { + async fn health(&self, request: Request) -> Result, Status>; + async fn predict(&self, request: Request) -> Result, Status>; + async fn load_model( + &self, + request: Request, + ) -> Result, Status>; + async fn predict_stream(&self, request: Request) + -> Result, Status>; // https://github.com/rust-lang/rust/issues/29661 + async fn embedding( + &self, + request: Request, + ) -> Result, Status>; + async fn generate_image( + &self, + request: Request, + ) -> Result, Status>; + async fn audio_transcription( + &self, + request: Request, + ) -> Result, Status>; + async fn tts(&self, request: Request) -> Result, Status>; + async fn tokenize_string( + &self, + request: Request, + ) -> Result, Status>; + async fn status( + &self, + request: Request, + ) -> Result, Status>; +} diff --git a/backend/rust/burn/Cargo.toml b/backend/rust/burn/Cargo.toml index dfc127d8f670..1d4c804214ed 100644 --- a/backend/rust/burn/Cargo.toml +++ b/backend/rust/burn/Cargo.toml @@ -9,15 +9,7 @@ edition = "2021" # import bunker here bunker = { path = "../bunker" } - -tonic = "0.10" -prost = "0.12" -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync", "time"] } -tokio-stream = "0.1" - -async-stream = "0.2" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -rand = "0.7" -tracing = "0.1" -tracing-subscriber = "0.3" +tokio = "1.33.0" +async-trait = "0.1.74" +tonic = "0.10.2" +tokio-stream = "0.1.14" diff --git a/backend/rust/burn/src/main.rs b/backend/rust/burn/src/main.rs index 4c5015975320..11035a14c582 100644 --- a/backend/rust/burn/src/main.rs +++ b/backend/rust/burn/src/main.rs @@ -1,50 +1,65 @@ -use tokio_stream::wrappers::ReceiverStream; -use tonic::transport::Server; -use tonic::{async_trait, Request, Response, Status}; -use tracing; -use tracing_subscriber; -use bunker::pb::{EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest}; use bunker::pb::Result as PbResult; +use bunker::pb::{ + EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, + StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest, +}; use bunker::service::BackendService; +use 
tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; +use async_trait::async_trait; // implement BackendService trait in bunker struct BurnBackend; #[async_trait] -impl BackendService>> for BurnBackend{ - - async fn health(&self, request: Request) -> Result,Status> { +impl BackendService>> for BurnBackend { + async fn health(&self, request: Request) -> Result, Status> { // return a Result,Status> let reply = Reply { message: "OK".into(), }; - let res=Response::new(reply); + let res = Response::new(reply); Ok(res) } - async fn predict(&self, request: Request) -> Result,Status> { + async fn predict(&self, request: Request) -> Result, Status> { todo!() } - async fn load_model(&self, request: Request) -> Result, Status> { + async fn load_model( + &self, + request: Request, + ) -> Result, Status> { todo!() } - async fn predict_stream(&self, request: Request) -> Result>>, Status> { + async fn predict_stream( + &self, + request: Request, + ) -> Result>>, Status> { todo!() } - async fn embedding(&self, request: Request) -> Result, Status> { + async fn embedding( + &self, + request: Request, + ) -> Result, Status> { todo!() } - async fn generate_image(&self, request: Request) -> Result, Status> { + async fn generate_image( + &self, + request: Request, + ) -> Result, Status> { todo!() } - async fn audio_transcription(&self, request: Request) -> Result, Status> { + async fn audio_transcription( + &self, + request: Request, + ) -> Result, Status> { todo!() } @@ -52,38 +67,22 @@ impl BackendService>> for BurnBackend{ todo!() } - async fn tokenize_string(&self, request: Request) -> Result, Status> { + async fn tokenize_string( + &self, + request: Request, + ) -> Result, Status> { todo!() } - async fn status(&self, request: Request) -> Result, Status> { + async fn status( + &self, + request: Request, + ) -> Result, Status> { todo!() } - } - #[tokio::main] async fn main() -> Result<(), Box> { - - let subscriber = tracing_subscriber::fmt() - .compact() - .with_file(true) - .with_line_number(true) - .with_target(true) - .with_level(true) - .finish(); - - tracing::subscriber::set_global_default(subscriber) - .expect("setting default subscriber failed"); - - let addr = "[::1]:50052".parse().unwrap(); - - let backend = BackendService {}; - - let svc = BurnBackend::new(backend); - - Server::builder().add_service(svc).serve(addr).await?; - - Ok(()) + todo!() } diff --git a/backend/rust/codegen/Cargo.toml b/backend/rust/codegen/Cargo.toml new file mode 100644 index 000000000000..c8eadcf13764 --- /dev/null +++ b/backend/rust/codegen/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "codegen" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[build-dependencies] +tonic-build = "0.10" diff --git a/backend/rust/codegen/build.rs b/backend/rust/codegen/build.rs new file mode 100644 index 000000000000..f702cc7fdc62 --- /dev/null +++ b/backend/rust/codegen/build.rs @@ -0,0 +1,11 @@ +fn main() { + tonic_build::configure() + .out_dir("../bunker/generated") + .build_server(true) + .build_client(false) + .compile( + &["../../../pkg/grpc/proto/backend.proto"], + &["../../../pkg/grpc/proto"], + ) + .unwrap(); +} diff --git a/backend/rust/codegen/src/lib.rs b/backend/rust/codegen/src/lib.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/backend/rust/codegen/src/lib.rs @@ -0,0 +1 @@ + From b92677b3bf93be2fa1c47b33398902cdcaceb0c3 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Fri, 20 Oct 2023 
10:37:00 +1100 Subject: [PATCH 06/18] Update backend/rust/bunker/src/lib.rs Co-authored-by: Luca Barbato Signed-off-by: Aisuko --- backend/rust/bunker/src/lib.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/backend/rust/bunker/src/lib.rs b/backend/rust/bunker/src/lib.rs index b4c8fb2726c2..b1a3c3508b12 100644 --- a/backend/rust/bunker/src/lib.rs +++ b/backend/rust/bunker/src/lib.rs @@ -3,4 +3,18 @@ pub mod pb { include!("../generated/backend.rs"); } -pub mod service; +use tonic::transport::Server; + +pub use crate::pb::backend_server::Backend as BackendService; +use crate::pb::backend_server::BackendServer; + +// Run the backend with the default behavior +pub async fn run(backend: impl BackendService, addr: impl Into) -> anyhow::Result<()> { + let svc = BackendServer::new(backend); + + let r = Server::builder() + .add_service(svc) + .serve(addr.into()) + .await?; + + Ok(r) From a2bb86f5a91d4ef813f2e02302c9eb3dc4991b44 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Fri, 20 Oct 2023 11:25:37 +1100 Subject: [PATCH 07/18] Remove services.rs Signed-off-by: Aisuko --- backend/rust/bunker/src/lib.rs | 7 ++++- backend/rust/bunker/src/service.rs | 43 ------------------------------ backend/rust/burn/src/main.rs | 8 ++++-- backend/rust/codegen/src/lib.rs | 3 ++- 4 files changed, 14 insertions(+), 47 deletions(-) delete mode 100644 backend/rust/bunker/src/service.rs diff --git a/backend/rust/bunker/src/lib.rs b/backend/rust/bunker/src/lib.rs index b1a3c3508b12..9e0f781e0443 100644 --- a/backend/rust/bunker/src/lib.rs +++ b/backend/rust/bunker/src/lib.rs @@ -3,13 +3,17 @@ pub mod pb { include!("../generated/backend.rs"); } +use std::net::SocketAddr; use tonic::transport::Server; pub use crate::pb::backend_server::Backend as BackendService; use crate::pb::backend_server::BackendServer; // Run the backend with the default behavior -pub async fn run(backend: impl BackendService, addr: impl Into) -> anyhow::Result<()> { +pub async fn run( + backend: impl BackendService, + addr: impl Into, +) -> Result<(), Box> { let svc = BackendServer::new(backend); let r = Server::builder() @@ -18,3 +22,4 @@ pub async fn run(backend: impl BackendService, addr: impl Into) -> a .await?; Ok(r) +} diff --git a/backend/rust/bunker/src/service.rs b/backend/rust/bunker/src/service.rs deleted file mode 100644 index 81cf3132cea0..000000000000 --- a/backend/rust/bunker/src/service.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! Contains the service trait for the bunker service. 
- -use crate::pb::Result as PbResult; -use crate::pb::{ - EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, - StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest, -}; -use async_trait::async_trait; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; - -#[async_trait] -pub trait BackendService>> { - async fn health(&self, request: Request) -> Result, Status>; - async fn predict(&self, request: Request) -> Result, Status>; - async fn load_model( - &self, - request: Request, - ) -> Result, Status>; - async fn predict_stream(&self, request: Request) - -> Result, Status>; // https://github.com/rust-lang/rust/issues/29661 - async fn embedding( - &self, - request: Request, - ) -> Result, Status>; - async fn generate_image( - &self, - request: Request, - ) -> Result, Status>; - async fn audio_transcription( - &self, - request: Request, - ) -> Result, Status>; - async fn tts(&self, request: Request) -> Result, Status>; - async fn tokenize_string( - &self, - request: Request, - ) -> Result, Status>; - async fn status( - &self, - request: Request, - ) -> Result, Status>; -} diff --git a/backend/rust/burn/src/main.rs b/backend/rust/burn/src/main.rs index 11035a14c582..409363eb5688 100644 --- a/backend/rust/burn/src/main.rs +++ b/backend/rust/burn/src/main.rs @@ -3,7 +3,8 @@ use bunker::pb::{ EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest, }; -use bunker::service::BackendService; + +use bunker::BackendService; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; @@ -14,7 +15,9 @@ use async_trait::async_trait; struct BurnBackend; #[async_trait] -impl BackendService>> for BurnBackend { +impl BackendService for BurnBackend { + type PredictStreamStream = ReceiverStream>; + async fn health(&self, request: Request) -> Result, Status> { // return a Result,Status> let reply = Reply { @@ -84,5 +87,6 @@ impl BackendService>> for BurnBackend { #[tokio::main] async fn main() -> Result<(), Box> { + // call bunker::run with BurnBackend todo!() } diff --git a/backend/rust/codegen/src/lib.rs b/backend/rust/codegen/src/lib.rs index 8b137891791f..208bc595e4cc 100644 --- a/backend/rust/codegen/src/lib.rs +++ b/backend/rust/codegen/src/lib.rs @@ -1 +1,2 @@ - +//! Here is for the more complex situation of code generation. For example, defind the constant for +//! the different build target. 
\ No newline at end of file From bc6c1fcef87962ee87f38645f97678ca0098916d Mon Sep 17 00:00:00 2001 From: Aisuko Date: Fri, 20 Oct 2023 18:34:23 +1100 Subject: [PATCH 08/18] Add test health in Makefile Signed-off-by: Aisuko --- backend/rust/Cargo.toml | 2 +- backend/rust/Makefile | 12 +++++++++++- backend/rust/README.md | 19 ++++++++++++++++++- backend/rust/burn/Cargo.toml | 4 ++++ backend/rust/burn/Makefile | 10 +++++----- backend/rust/burn/src/{main.rs => server.rs} | 13 ++++++++++++- backend/rust/codegen/src/lib.rs | 4 ++-- 7 files changed, 53 insertions(+), 11 deletions(-) rename backend/rust/burn/src/{main.rs => server.rs} (88%) diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 15a22f54d8b6..183192a8eef4 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -4,4 +4,4 @@ members = [ "bunker", "burn", "codegen", -] \ No newline at end of file +] diff --git a/backend/rust/Makefile b/backend/rust/Makefile index 2a18f0facd82..283ebc08b61a 100644 --- a/backend/rust/Makefile +++ b/backend/rust/Makefile @@ -7,7 +7,7 @@ all: fmt check .PHONY: fmt fmt: @echo "Formatting code..." - @cargo fmt --all -- --check + @cargo fmt --all --check .PHONY: build build: @@ -28,3 +28,13 @@ check: clean: @echo "Cleaning..." @cargo clean + +.PHONY: burn +burn: + @echo "Burning..." + @cargo run --bin server --package burn + +.PHONY: health +health: + @echo "Burning..." + grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto -d '{}' '[::1]:50051' backend.Backend/Health diff --git a/backend/rust/README.md b/backend/rust/README.md index de0012b9b92d..ae1e2ff2eb9e 100644 --- a/backend/rust/README.md +++ b/backend/rust/README.md @@ -34,4 +34,21 @@ brew install protobuf ``` cargo fmt --all --check -``` \ No newline at end of file +``` + +### Check the gRPC backend status + +It will return base64 encoded string of the `OK`. + + +```bash +make burn + +make test +``` + +``` +{ + "message": "T0s=" +} +``` diff --git a/backend/rust/burn/Cargo.toml b/backend/rust/burn/Cargo.toml index 1d4c804214ed..902f2104080f 100644 --- a/backend/rust/burn/Cargo.toml +++ b/backend/rust/burn/Cargo.toml @@ -5,6 +5,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "server" +path = "src/server.rs" + [dependencies] # import bunker here diff --git a/backend/rust/burn/Makefile b/backend/rust/burn/Makefile index c5871f45ade3..12c9ff3073e0 100644 --- a/backend/rust/burn/Makefile +++ b/backend/rust/burn/Makefile @@ -3,16 +3,16 @@ check: @echo "Checking..." @cargo check +.PHONY: fmt +fmt: + @echo "Formatting code..." + @cargo fmt + .PHONY: build build: @echo "Building..." @cargo build -.PHONY: run -run: - @echo "Running..." - @cargo run --release - .PHONY: doc doc: @echo "Documenting..." 
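The `"T0s="` shown in the README above is simply the base64 encoding of the bytes `OK`; grpcurl base64-encodes `bytes` fields in its JSON output. Besides the `make health` target, the handler can also be exercised in-process. A minimal, hypothetical test sketch (not part of this patch) that could sit at the bottom of `server.rs`, assuming `Reply.message` is a `bytes` field and that tokio's `macros`/`rt` features are available to the crate:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::Request;

    #[tokio::test]
    async fn health_replies_ok() {
        // BurnBackend derives Default, so it can be constructed directly.
        let backend = BurnBackend::default();
        let reply = backend
            .health(Request::new(HealthMessage::default()))
            .await
            .expect("health should succeed")
            .into_inner();
        // grpcurl renders this as "T0s=" because the field is raw bytes.
        assert_eq!(reply.message, "OK".as_bytes());
    }
}
```

Such a test would run with `cargo test` in the `burn` crate, or via the workspace-wide `make test` target.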
diff --git a/backend/rust/burn/src/main.rs b/backend/rust/burn/src/server.rs similarity index 88% rename from backend/rust/burn/src/main.rs rename to backend/rust/burn/src/server.rs index 409363eb5688..c028cc93a97e 100644 --- a/backend/rust/burn/src/main.rs +++ b/backend/rust/burn/src/server.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use bunker::pb::Result as PbResult; use bunker::pb::{ EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, @@ -12,6 +14,7 @@ use async_trait::async_trait; // implement BackendService trait in bunker +#[derive(Default, Debug)] struct BurnBackend; #[async_trait] @@ -88,5 +91,13 @@ impl BackendService for BurnBackend { #[tokio::main] async fn main() -> Result<(), Box> { // call bunker::run with BurnBackend - todo!() + let burn_backend = BurnBackend {}; + let addr = "[::1]:50051" + .parse::() + .expect("Failed to parse address"); + + // Implmenet Into for addr + let result = bunker::run(burn_backend, addr).await?; + + Ok(result) } diff --git a/backend/rust/codegen/src/lib.rs b/backend/rust/codegen/src/lib.rs index 208bc595e4cc..219ad14b9ebb 100644 --- a/backend/rust/codegen/src/lib.rs +++ b/backend/rust/codegen/src/lib.rs @@ -1,2 +1,2 @@ -//! Here is for the more complex situation of code generation. For example, defind the constant for -//! the different build target. \ No newline at end of file +//! Here is for the more complex situation of code generation. For example, defind the constant for +//! the different build target. From 7e9215fed8fde3c50093ef00ebda88d0bcaf4015 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Mon, 23 Oct 2023 16:33:52 +1100 Subject: [PATCH 09/18] Ignore Cargo.lock Signed-off-by: Aisuko --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3986e42857f2..8109a5f58150 100644 --- a/.gitignore +++ b/.gitignore @@ -41,4 +41,5 @@ release/ backend-assets/ prepare /ggml-metal.metal -/backend/rust/target/ +backend/rust/target/ +backend/rust/*.lock From bd087ca915b50b50bed6ecc05613ef6351fa27ac Mon Sep 17 00:00:00 2001 From: Aisuko Date: Tue, 24 Oct 2023 11:19:43 +1100 Subject: [PATCH 10/18] Update .gitignore Co-authored-by: Luca Barbato Signed-off-by: Aisuko --- .gitignore | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 8109a5f58150..384da58dc30f 100644 --- a/.gitignore +++ b/.gitignore @@ -41,5 +41,5 @@ release/ backend-assets/ prepare /ggml-metal.metal -backend/rust/target/ -backend/rust/*.lock +target/ +Cargo.lock From 4397789e507bffee6c7a31be5ca97fde0f85be19 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Tue, 24 Oct 2023 13:44:02 +1100 Subject: [PATCH 11/18] Add tracing Signed-off-by: Aisuko --- backend/rust/Cargo.lock | 2 ++ backend/rust/burn/Cargo.toml | 4 ++++ backend/rust/burn/src/server.rs | 29 +++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/backend/rust/Cargo.lock b/backend/rust/Cargo.lock index f4d4397ca217..dc09e2752753 100644 --- a/backend/rust/Cargo.lock +++ b/backend/rust/Cargo.lock @@ -196,6 +196,8 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tracing", + "tracing-subscriber", ] [[package]] diff --git a/backend/rust/burn/Cargo.toml b/backend/rust/burn/Cargo.toml index 902f2104080f..acc58a67e9ce 100644 --- a/backend/rust/burn/Cargo.toml +++ b/backend/rust/burn/Cargo.toml @@ -13,7 +13,11 @@ path = "src/server.rs" # import bunker here bunker = { path = "../bunker" } + tokio = "1.33.0" async-trait = "0.1.74" tonic = "0.10.2" tokio-stream = "0.1.14" + 
+tracing = "0.1" +tracing-subscriber = "0.3" \ No newline at end of file diff --git a/backend/rust/burn/src/server.rs b/backend/rust/burn/src/server.rs index c028cc93a97e..6bbb2f7c6b79 100644 --- a/backend/rust/burn/src/server.rs +++ b/backend/rust/burn/src/server.rs @@ -12,6 +12,8 @@ use tonic::{Request, Response, Status}; use async_trait::async_trait; +use tracing::{event, span, Level}; + // implement BackendService trait in bunker #[derive(Default, Debug)] @@ -21,6 +23,7 @@ struct BurnBackend; impl BackendService for BurnBackend { type PredictStreamStream = ReceiverStream>; + #[tracing::instrument] async fn health(&self, request: Request) -> Result, Status> { // return a Result,Status> let reply = Reply { @@ -30,10 +33,12 @@ impl BackendService for BurnBackend { Ok(res) } + #[tracing::instrument] async fn predict(&self, request: Request) -> Result, Status> { todo!() } + #[tracing::instrument] async fn load_model( &self, request: Request, @@ -41,6 +46,7 @@ impl BackendService for BurnBackend { todo!() } + #[tracing::instrument] async fn predict_stream( &self, request: Request, @@ -48,6 +54,7 @@ impl BackendService for BurnBackend { todo!() } + #[tracing::instrument] async fn embedding( &self, request: Request, @@ -55,6 +62,7 @@ impl BackendService for BurnBackend { todo!() } + #[tracing::instrument] async fn generate_image( &self, request: Request, @@ -62,6 +70,7 @@ impl BackendService for BurnBackend { todo!() } + #[tracing::instrument] async fn audio_transcription( &self, request: Request, @@ -69,10 +78,12 @@ impl BackendService for BurnBackend { todo!() } + #[tracing::instrument] async fn tts(&self, request: Request) -> Result, Status> { todo!() } + #[tracing::instrument] async fn tokenize_string( &self, request: Request, @@ -80,6 +91,7 @@ impl BackendService for BurnBackend { todo!() } + #[tracing::instrument] async fn status( &self, request: Request, @@ -90,6 +102,16 @@ impl BackendService for BurnBackend { #[tokio::main] async fn main() -> Result<(), Box> { + let subscriber = tracing_subscriber::fmt() + .compact() + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_target(false) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + // call bunker::run with BurnBackend let burn_backend = BurnBackend {}; let addr = "[::1]:50051" @@ -99,5 +121,12 @@ async fn main() -> Result<(), Box> { // Implmenet Into for addr let result = bunker::run(burn_backend, addr).await?; + event!(Level::INFO, "Burn Server is starting"); + + let span = span!(Level::INFO, "Burn Server"); + let _enter = span.enter(); + + event!(Level::INFO, "Burn Server started successfully"); + Ok(result) } From c0dadcc9d24091c6c8114374cab294b2152dd6aa Mon Sep 17 00:00:00 2001 From: Aisuko Date: Wed, 1 Nov 2023 20:12:03 +1100 Subject: [PATCH 12/18] Add new model Signed-off-by: Aisuko --- backend/rust/Cargo.lock | 1335 ----------------- backend/rust/Cargo.toml | 2 +- backend/rust/Makefile | 2 +- .../rust/{burn => backend-burn}/Cargo.toml | 4 +- backend/rust/{burn => backend-burn}/Makefile | 0 .../server.rs => backend-burn/src/main.rs} | 0 6 files changed, 4 insertions(+), 1339 deletions(-) delete mode 100644 backend/rust/Cargo.lock rename backend/rust/{burn => backend-burn}/Cargo.toml (89%) rename backend/rust/{burn => backend-burn}/Makefile (100%) rename backend/rust/{burn/src/server.rs => backend-burn/src/main.rs} (100%) diff --git a/backend/rust/Cargo.lock b/backend/rust/Cargo.lock deleted file mode 100644 index dc09e2752753..000000000000 --- a/backend/rust/Cargo.lock +++ 
/dev/null @@ -1,1335 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "addr2line" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "aho-corasick" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" -dependencies = [ - "memchr", -] - -[[package]] -name = "anyhow" -version = "1.0.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" - -[[package]] -name = "async-stream" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5" -dependencies = [ - "async-stream-impl 0.2.1", - "futures-core", -] - -[[package]] -name = "async-stream" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" -dependencies = [ - "async-stream-impl 0.3.5", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "async-trait" -version = "0.1.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "axum" -version = "0.6.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" -dependencies = [ - "async-trait", - "axum-core", - "bitflags 1.3.2", - "bytes", - "futures-util", - "http", - "http-body", - "hyper", - "itoa", - "matchit", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "sync_wrapper", - "tower", - "tower-layer", - "tower-service", -] - -[[package]] -name = "axum-core" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "http", - "http-body", - "mime", - "rustversion", - "tower-layer", - "tower-service", -] - -[[package]] -name = "backtrace" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - -[[package]] -name = "base64" -version = "0.21.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" - -[[package]] -name = "bunker" -version = "0.1.0" -dependencies = [ - "async-stream 0.2.1", - "async-trait", - "prost", - "rand 0.7.3", - "serde", - "serde_json", - "tokio", - "tokio-stream", - "tonic", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "burn" -version = "0.1.0" -dependencies = [ - "async-trait", - "bunker", - "tokio", - "tokio-stream", - "tonic", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "bytes" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" - -[[package]] -name = "cc" -version = "1.0.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "codegen" -version = "0.1.0" -dependencies = [ - "tonic-build", -] - -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "errno" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" -dependencies = [ - "libc", - "windows-sys", -] - -[[package]] -name = "fastrand" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" - -[[package]] -name = "fixedbitset" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "futures-channel" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" -dependencies = [ - "futures-core", -] - -[[package]] -name = "futures-core" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" - -[[package]] -name = "futures-sink" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" - -[[package]] -name = "futures-task" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" - -[[package]] -name = "futures-util" -version = "0.3.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" -dependencies = [ - "futures-core", - "futures-task", - "pin-project-lite", - "pin-utils", -] - -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - -[[package]] -name = "getrandom" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.11.0+wasi-snapshot-preview1", -] - -[[package]] -name = "gimli" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" - -[[package]] -name = "h2" -version = "0.3.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91fc23aa11be92976ef4729127f1a74adf36d8436f7816b185d18df956790833" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http", - "indexmap 1.9.3", - "slab", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - -[[package]] -name = "hashbrown" -version = "0.14.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" - -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - -[[package]] -name = "hermit-abi" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" - -[[package]] -name = "home" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" -dependencies = [ - "windows-sys", -] - -[[package]] -name = "http" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - -[[package]] -name = "http-body" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" -dependencies = [ - "bytes", - "http", - "pin-project-lite", -] - -[[package]] -name = "httparse" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" - -[[package]] -name = "httpdate" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" - -[[package]] -name = "hyper" -version = "0.14.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2", - "http", - "http-body", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2 0.4.9", - "tokio", - "tower-service", - "tracing", - "want", -] - -[[package]] -name = "hyper-timeout" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" -dependencies = [ - "hyper", - "pin-project-lite", - "tokio", - "tokio-io-timeout", -] - -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - -[[package]] -name = "indexmap" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" -dependencies = [ - "equivalent", - "hashbrown 0.14.1", -] - -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "libc" -version = "0.2.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" - -[[package]] -name = "linux-raw-sys" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" - -[[package]] -name = "log" -version = "0.4.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" - -[[package]] -name = "matchit" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" - -[[package]] -name = "memchr" -version = "2.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" - -[[package]] -name = "mime" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" - -[[package]] -name = "miniz_oxide" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" -dependencies 
= [ - "adler", -] - -[[package]] -name = "mio" -version = "0.8.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" -dependencies = [ - "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", -] - -[[package]] -name = "multimap" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" - -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "object" -version = "0.32.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" -dependencies = [ - "memchr", -] - -[[package]] -name = "once_cell" -version = "1.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" - -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - -[[package]] -name = "percent-encoding" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" - -[[package]] -name = "petgraph" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" -dependencies = [ - "fixedbitset", - "indexmap 2.0.2", -] - -[[package]] -name = "pin-project" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "prettyplease" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" -dependencies = [ - "proc-macro2", - "syn 2.0.38", -] - -[[package]] -name = "proc-macro2" -version = "1.0.69" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "prost" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fdd22f3b9c31b53c060df4a0613a1c7f062d4115a2b984dd15b1858f7e340d" -dependencies = [ - "bytes", - "prost-derive", -] - -[[package]] -name = "prost-build" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bdf592881d821b83d471f8af290226c8d51402259e9bb5be7f9f8bdebbb11ac" -dependencies = [ - "bytes", - "heck", - "itertools", - "log", - "multimap", - "once_cell", - "petgraph", - "prettyplease", - "prost", - "prost-types", - "regex", - "syn 2.0.38", - "tempfile", - "which", -] - -[[package]] -name = "prost-derive" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "265baba7fabd416cf5078179f7d2cbeca4ce7a9041111900675ea7c4cb8a4c32" -dependencies = [ - "anyhow", - "itertools", - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "prost-types" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e081b29f63d83a4bc75cfc9f3fe424f9156cf92d8a4f0c9407cce9a1b67327cf" -dependencies = [ - "prost", -] - -[[package]] -name = "quote" -version = "1.0.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom 0.2.10", -] - -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", -] - -[[package]] -name = "redox_syscall" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "regex" -version = "1.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" - -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - -[[package]] -name = "rustix" -version = "0.38.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" -dependencies = [ - "bitflags 2.4.1", - "errno", - "libc", - "linux-raw-sys", - "windows-sys", -] - -[[package]] -name = "rustversion" -version = "1.0.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" - -[[package]] -name = "ryu" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" - -[[package]] -name = "serde" -version = "1.0.189" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.189" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "serde_json" -version = "1.0.107" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "slab" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" - -[[package]] -name = "socket2" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" -dependencies = [ - "libc", - "winapi", -] - -[[package]] -name = "socket2" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" -dependencies = [ - "libc", - "windows-sys", -] - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "sync_wrapper" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" - -[[package]] -name = "tempfile" -version = "3.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" -dependencies = [ - "cfg-if", - "fastrand", - "redox_syscall", - "rustix", - "windows-sys", -] - -[[package]] -name = "thread_local" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" -dependencies = [ - "cfg-if", - "once_cell", -] - -[[package]] -name = "tokio" -version = "1.33.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" -dependencies = [ - "backtrace", - "bytes", - "libc", - "mio", - "num_cpus", - "pin-project-lite", - "socket2 0.5.4", - "tokio-macros", - "windows-sys", -] - -[[package]] -name = "tokio-io-timeout" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" -dependencies = [ - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-macros" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "tokio-stream" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-util" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d68074620f57a0b21594d9735eb2e98ab38b17f80d3fcb189fca266771ca60d" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", - "tracing", -] - -[[package]] -name = "tonic" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e" -dependencies = [ - "async-stream 0.3.5", - "async-trait", - "axum", - "base64", - "bytes", - "h2", - "http", - "http-body", - "hyper", - "hyper-timeout", - "percent-encoding", - "pin-project", - "prost", - "tokio", - "tokio-stream", - "tower", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tonic-build" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9d021fc044c18582b9a2408cd0dd05b1596e3ecdb5c4df822bb0183545683889" -dependencies = [ - "prettyplease", - "proc-macro2", - "prost-build", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "indexmap 1.9.3", - "pin-project", - "pin-project-lite", - "rand 0.8.5", - "slab", - "tokio", - "tokio-util", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-layer" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" - -[[package]] -name = "tower-service" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" - -[[package]] -name = "tracing" -version = "0.1.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9" -dependencies = [ - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" -dependencies = [ - "lazy_static", - "log", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" -dependencies = [ - "nu-ansi-term", - "sharded-slab", - "smallvec", - "thread_local", - "tracing-core", - "tracing-log", -] - -[[package]] -name = "try-lock" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "valuable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 183192a8eef4..3c6e726b6958 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -2,6 +2,6 @@ resolver = "2" members = [ "bunker", - "burn", + "backend-burn", "codegen", ] diff --git a/backend/rust/Makefile b/backend/rust/Makefile 
index 283ebc08b61a..2b25e025084a 100644 --- a/backend/rust/Makefile +++ b/backend/rust/Makefile @@ -32,7 +32,7 @@ clean: .PHONY: burn burn: @echo "Burning..." - @cargo run --bin server --package burn + @cargo run --bin server --package backend-burn .PHONY: health health: diff --git a/backend/rust/burn/Cargo.toml b/backend/rust/backend-burn/Cargo.toml similarity index 89% rename from backend/rust/burn/Cargo.toml rename to backend/rust/backend-burn/Cargo.toml index acc58a67e9ce..007b3b44a4ad 100644 --- a/backend/rust/burn/Cargo.toml +++ b/backend/rust/backend-burn/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "burn" +name = "backend-burn" version = "0.1.0" edition = "2021" @@ -7,7 +7,7 @@ edition = "2021" [[bin]] name = "server" -path = "src/server.rs" +path = "src/main.rs" [dependencies] diff --git a/backend/rust/burn/Makefile b/backend/rust/backend-burn/Makefile similarity index 100% rename from backend/rust/burn/Makefile rename to backend/rust/backend-burn/Makefile diff --git a/backend/rust/burn/src/server.rs b/backend/rust/backend-burn/src/main.rs similarity index 100% rename from backend/rust/burn/src/server.rs rename to backend/rust/backend-burn/src/main.rs From fb67c91a053f4ed0342cee30d8fccd4888cbf63b Mon Sep 17 00:00:00 2001 From: Aisuko Date: Wed, 1 Nov 2023 21:55:09 +1100 Subject: [PATCH 13/18] Implement a new simple model Signed-off-by: Aisuko --- backend/rust/Cargo.toml | 3 +- backend/rust/models/Cargo.toml | 10 +++ backend/rust/models/src/lib.rs | 1 + backend/rust/models/src/onnx/inference.rs | 1 + backend/rust/models/src/onnx/mod.rs | 90 +++++++++++++++++++++++ 5 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 backend/rust/models/Cargo.toml create mode 100644 backend/rust/models/src/lib.rs create mode 100644 backend/rust/models/src/onnx/inference.rs create mode 100644 backend/rust/models/src/onnx/mod.rs diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 3c6e726b6958..60719544e525 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -4,4 +4,5 @@ members = [ "bunker", "backend-burn", "codegen", -] + "models", +] \ No newline at end of file diff --git a/backend/rust/models/Cargo.toml b/backend/rust/models/Cargo.toml new file mode 100644 index 000000000000..f092f75b6c98 --- /dev/null +++ b/backend/rust/models/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "models" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +burn = { version="0.10.0", features=["ndarray"] } # https://github.com/mudler/LocalAI/discussions/1219 +serde = "1.0.190" diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs new file mode 100644 index 000000000000..bb2b1591647f --- /dev/null +++ b/backend/rust/models/src/lib.rs @@ -0,0 +1 @@ +pub(crate) mod onnx; diff --git a/backend/rust/models/src/onnx/inference.rs b/backend/rust/models/src/onnx/inference.rs new file mode 100644 index 000000000000..febee9cf4b0b --- /dev/null +++ b/backend/rust/models/src/onnx/inference.rs @@ -0,0 +1 @@ +use std::env::args; diff --git a/backend/rust/models/src/onnx/mod.rs b/backend/rust/models/src/onnx/mod.rs new file mode 100644 index 000000000000..6cb6eb000455 --- /dev/null +++ b/backend/rust/models/src/onnx/mod.rs @@ -0,0 +1,90 @@ +//! Defination of a mninst model and config of it. +//! The source code is from https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/src/model.rs +//! The license is Apache-2.0 and MIT. +//! 
Adapter by Aisuko + +pub(crate) mod inference; +use inference::*; + +use burn::{ + module::Module, + nn::{self, BatchNorm, PaddingConfig2d}, + tensor::{backend::Backend, Tensor}, +}; + +const NUM_CLASSES: usize = 10; + +#[derive(Module, Debug)] +/// A struct representing an ONNX model. +pub struct Model { + /// The first convolutional block of the model. + conv1: ConvBlock, + /// The second convolutional block of the model. + conv2: ConvBlock, + /// The third convolutional block of the model. + conv3: ConvBlock, + /// A dropout layer used in the model. + dropout: nn::Dropout, + /// The first fully connected layer of the model. + fc1: nn::Linear, + /// The second fully connected layer of the model. + fc2: nn::Linear, + /// The activation function used in the model. + activation: nn::GELU, +} + +impl Model { + pub fn new() -> Self { + todo!("Implement the Model::new() function") + } + + pub fn forward(&self, input: Tensor) -> Tensor { + todo!("Implement the Model::forward() function") + } +} + +/// A struct representing a convolutional block in a neural network model. +#[derive(Module, Debug)] +pub struct ConvBlock { + /// A 2D convolutional layer. + conv: nn::conv::Conv2d, + /// A batch normalization layer. + norm: BatchNorm, + /// A GELU activation function. + activation: nn::GELU, +} + +/// A convolutional block with batch normalization and GELU activation. +impl ConvBlock { + /// Creates a new `ConvBlock` with the given number of output channels and kernel size. + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + // Initialize a 2D convolutional layer with the given output channels and kernel size, + // and set the padding to "valid". + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + + // Initialize a batch normalization layer with the number of channels in the second dimension of the output. + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + // Create a new `ConvBlock` with the initialized convolutional and batch normalization layers, + // and a GELU activation function. + Self { + conv: conv, + norm: norm, + activation: nn::GELU::new(), + } + } + + /// Applies the convolutional block to the given input tensor. + pub fn forward(&self, input: Tensor) -> Tensor { + // Apply the convolutional layer to the input tensor. + let x = self.conv.forward(input); + + // Apply the batch normalization layer to the output of the convolutional layer. + let x = self.norm.forward(x); + + // Apply the GELU activation function to the output of the batch normalization layer. 
+ self.activation.forward(x) + } +} From 1d2fd99d2f507476bba475976b1629b5ed2c69fa Mon Sep 17 00:00:00 2001 From: Aisuko Date: Thu, 2 Nov 2023 12:53:19 +1100 Subject: [PATCH 14/18] Implement MNIST model and inference Signed-off-by: Aisuko --- .gitignore | 1 + backend/rust/Makefile | 10 ++ backend/rust/backend-burn/Cargo.toml | 1 + backend/rust/backend-burn/src/main.rs | 21 ++- backend/rust/models/Cargo.toml | 9 +- backend/rust/models/src/lib.rs | 10 +- backend/rust/models/src/mnist/mnist.rs | 185 ++++++++++++++++++++++ backend/rust/models/src/mnist/mod.rs | 33 ++++ backend/rust/models/src/onnx/inference.rs | 1 - backend/rust/models/src/onnx/mod.rs | 90 ----------- 10 files changed, 267 insertions(+), 94 deletions(-) create mode 100644 backend/rust/models/src/mnist/mnist.rs create mode 100644 backend/rust/models/src/mnist/mod.rs delete mode 100644 backend/rust/models/src/onnx/inference.rs delete mode 100644 backend/rust/models/src/onnx/mod.rs diff --git a/.gitignore b/.gitignore index 384da58dc30f..ef15c70d86eb 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ prepare /ggml-metal.metal target/ Cargo.lock +model.bin diff --git a/backend/rust/Makefile b/backend/rust/Makefile index 2b25e025084a..ad550d67e9b9 100644 --- a/backend/rust/Makefile +++ b/backend/rust/Makefile @@ -34,6 +34,16 @@ burn: @echo "Burning..." @cargo run --bin server --package backend-burn + +############################################################################################################ +# gRPC testing commands + + +.PHONY: list +list: + @echo "Burning..." + @grpcurl -plaintext -import-path ../../../pkg/grpc/proto -proto backend.proto list backend.Backend + .PHONY: health health: @echo "Burning..." diff --git a/backend/rust/backend-burn/Cargo.toml b/backend/rust/backend-burn/Cargo.toml index 007b3b44a4ad..f97347d324b8 100644 --- a/backend/rust/backend-burn/Cargo.toml +++ b/backend/rust/backend-burn/Cargo.toml @@ -13,6 +13,7 @@ path = "src/main.rs" # import bunker here bunker = { path = "../bunker" } +models = { path = "../models" } tokio = "1.33.0" async-trait = "0.1.74" diff --git a/backend/rust/backend-burn/src/main.rs b/backend/rust/backend-burn/src/main.rs index 6bbb2f7c6b79..9ee1d96cf4c2 100644 --- a/backend/rust/backend-burn/src/main.rs +++ b/backend/rust/backend-burn/src/main.rs @@ -14,6 +14,7 @@ use async_trait::async_trait; use tracing::{event, span, Level}; +use models::*; // implement BackendService trait in bunker #[derive(Default, Debug)] @@ -35,7 +36,25 @@ impl BackendService for BurnBackend { #[tracing::instrument] async fn predict(&self, request: Request) -> Result, Status> { - todo!() + let mut models: Vec> = vec![Box::new(models::MNINST::new())]; + let result = models[0].predict(request.into_inner()); + + match result { + Ok(res) => { + let reply = Reply { + message: res.into(), + }; + let res = Response::new(reply); + Ok(res) + } + Err(e) => { + let reply = Reply { + message: e.to_string().into(), + }; + let res = Response::new(reply); + Ok(res) + } + } } #[tracing::instrument] diff --git a/backend/rust/models/Cargo.toml b/backend/rust/models/Cargo.toml index f092f75b6c98..8cb651876cf4 100644 --- a/backend/rust/models/Cargo.toml +++ b/backend/rust/models/Cargo.toml @@ -5,6 +5,13 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["ndarray"] + +ndarray = ["burn/ndarray"] +wgpu = ["burn/wgpu"] + [dependencies] -burn = { version="0.10.0", features=["ndarray"] } # 
https://github.com/mudler/LocalAI/discussions/1219 +bunker = { path = "../bunker" } +burn = { version="0.10.0", features=["ndarray","wgpu"] } # https://github.com/mudler/LocalAI/discussions/1219 serde = "1.0.190" diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs index bb2b1591647f..f3302e83ef73 100644 --- a/backend/rust/models/src/lib.rs +++ b/backend/rust/models/src/lib.rs @@ -1 +1,9 @@ -pub(crate) mod onnx; +pub(crate) mod mnist; +pub use mnist::mnist::MNINST; + +use bunker::pb::{ModelOptions, PredictOptions}; + +pub trait LLM { + fn load_model(&mut self, request: ModelOptions) -> Result>; + fn predict(&mut self, request: PredictOptions) -> Result>; +} diff --git a/backend/rust/models/src/mnist/mnist.rs b/backend/rust/models/src/mnist/mnist.rs new file mode 100644 index 000000000000..995b2706ed05 --- /dev/null +++ b/backend/rust/models/src/mnist/mnist.rs @@ -0,0 +1,185 @@ +//! Defination of a mninst model and config of it. +//! The source code is from https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/src/model.rs +//! The license is Apache-2.0 and MIT. +//! Adapter by Aisuko + +use burn::{ + backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuDevice}, + module::Module, + nn::{self, BatchNorm, PaddingConfig2d}, + record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, + tensor::{backend::Backend, Tensor}, +}; + +// https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin +static STATE_ENCODED: &[u8] = include_bytes!("model.bin"); + +const NUM_CLASSES: usize = 10; + +#[derive(Module, Debug)] +/// A struct representing an MNINST model. +pub struct MNINST { + /// The first convolutional block of the model. + conv1: ConvBlock, + /// The second convolutional block of the model. + conv2: ConvBlock, + /// The third convolutional block of the model. + conv3: ConvBlock, + /// A dropout layer used in the model. + dropout: nn::Dropout, + /// The first fully connected layer of the model. + fc1: nn::Linear, + /// The second fully connected layer of the model. + fc2: nn::Linear, + /// The activation function used in the model. + activation: nn::GELU, +} + +impl MNINST { + pub fn new() -> Self { + let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size + let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size + let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size + let hidden_size = 24 * 22 * 22; + let fc1 = nn::LinearConfig::new(hidden_size, 32) + .with_bias(false) + .init(); + let fc2 = nn::LinearConfig::new(32, NUM_CLASSES) + .with_bias(false) + .init(); + + let dropout = nn::DropoutConfig::new(0.5).init(); + + let instance = Self { + conv1: conv1, + conv2: conv2, + conv3: conv3, + dropout: dropout, + fc1: fc1, + fc2: fc2, + activation: nn::GELU::new(), + }; + let record = BinBytesRecorder::::default() + .load(STATE_ENCODED.to_vec()) + .expect("Failed to decode state"); + + instance.load_record(record) + } + + /// Applies the forward pass of the neural network on the given input tensor. + /// + /// # Arguments + /// + /// * `input` - A 3-dimensional tensor of shape [batch_size, height, width]. + /// + /// # Returns + /// + /// A 2-dimensional tensor of shape [batch_size, num_classes] containing the output of the neural network. 
+ pub fn forward(&self, input: Tensor) -> Tensor { + // Get the dimensions of the input tensor + let [batch_size, height, width] = input.dims(); + // Reshape the input tensor to have a shape of [batch_size, 1, height, width] and detach it + let x = input.reshape([batch_size, 1, height, width]).detach(); + // Apply the first convolutional layer to the input tensor + let x = self.conv1.forward(x); + // Apply the second convolutional layer to the output of the first convolutional layer + let x = self.conv2.forward(x); + // Apply the third convolutional layer to the output of the second convolutional layer + let x = self.conv3.forward(x); + + // Get the dimensions of the output tensor from the third convolutional layer + let [batch_size, channels, height, width] = x.dims(); + // Reshape the output tensor to have a shape of [batch_size, channels*height*width] + let x = x.reshape([batch_size, channels * height * width]); + + // Apply dropout to the output of the third convolutional layer + let x = self.dropout.forward(x); + // Apply the first fully connected layer to the output of the dropout layer + let x = self.fc1.forward(x); + // Apply the activation function to the output of the first fully connected layer + let x = self.activation.forward(x); + + // Apply the second fully connected layer to the output of the activation function + self.fc2.forward(x) + } + + pub fn inference(&mut self, input: &[f32]) -> Result, Box> { + // Reshape from the 1D array to 3d tensor [batch, height, width] + let input: Tensor = Tensor::from_floats(input).reshape([1, 28, 28]); + + // Normalize input: make between [0,1] and make the mean=0 and std=1 + // values mean=0.1307, std=0.3081 + // Source: https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 + let input = ((input / 255) - 0.1307) / 0.3081; + + // Run the tensor input through the model + let output: Tensor = self.forward(input); + + // Convert the model output into probalibility distribution using softmax formula + let output = burn::tensor::activation::softmax(output, 1); + + // Flatten oupuut tensor with [1,10] shape into boxed slice of [f32] + let output = output.into_data().convert::().value; + + Ok(output) + } +} + +/// A struct representing a convolutional block in a neural network model. +#[derive(Module, Debug)] +pub struct ConvBlock { + /// A 2D convolutional layer. + conv: nn::conv::Conv2d, + /// A batch normalization layer. + norm: BatchNorm, + /// A GELU activation function. + activation: nn::GELU, +} + +/// A convolutional block with batch normalization and GELU activation. +impl ConvBlock { + /// Creates a new `ConvBlock` with the given number of output channels and kernel size. + pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { + // Initialize a 2D convolutional layer with the given output channels and kernel size, + // and set the padding to "valid". + let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) + .with_padding(PaddingConfig2d::Valid) + .init(); + + // Initialize a batch normalization layer with the number of channels in the second dimension of the output. + let norm = nn::BatchNormConfig::new(channels[1]).init(); + + // Create a new `ConvBlock` with the initialized convolutional and batch normalization layers, + // and a GELU activation function. + Self { + conv: conv, + norm: norm, + activation: nn::GELU::new(), + } + } + + /// Applies the convolutional block to the given input tensor. 
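The normalization in inference() above maps raw 0-255 pixel values into the mean/std space of the referenced PyTorch MNIST example (mean 0.1307, std 0.3081). A tiny self-contained sketch of that step, just to make the arithmetic concrete (not part of the patch):

// Mirrors the scaling in MNINST::inference: (pixel / 255 - 0.1307) / 0.3081.
fn normalize_pixel(raw: f32) -> f32 {
    ((raw / 255.0) - 0.1307) / 0.3081
}

fn main() {
    // A fully-on pixel (255) maps to roughly 2.82, an empty pixel (0) to roughly -0.42.
    println!("{:.2}", normalize_pixel(255.0));
    println!("{:.2}", normalize_pixel(0.0));
}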
+ pub fn forward(&self, input: Tensor) -> Tensor { + // Apply the convolutional layer to the input tensor. + let x = self.conv.forward(input); + + // Apply the batch normalization layer to the output of the convolutional layer. + let x = self.norm.forward(x); + + // Apply the GELU activation function to the output of the batch normalization layer. + self.activation.forward(x) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[cfg(feature = "ndarray")] + pub type Backend = burn::backend::NdArrayBackend; + #[test] + fn test_inference() { + let mut model = MNINST::::new(); + let output = model.inference(&[0.0; 28 * 28]).unwrap(); + assert_eq!(output.len(), 10); + } +} diff --git a/backend/rust/models/src/mnist/mod.rs b/backend/rust/models/src/mnist/mod.rs new file mode 100644 index 000000000000..d53b76c6c7a4 --- /dev/null +++ b/backend/rust/models/src/mnist/mod.rs @@ -0,0 +1,33 @@ +use crate::LLM; +use bunker::pb::{ModelOptions, PredictOptions}; + +pub(crate) mod mnist; + +#[cfg(feature = "ndarray")] +pub type Backend = burn::backend::NdArrayBackend; + +impl LLM for mnist::MNINST { + fn load_model(&mut self, request: ModelOptions) -> Result> { + todo!("load model") + } + + fn predict(&mut self, pre_ops: PredictOptions) -> Result> { + // convert prost::alloc::string::String to &[f32] + let input = pre_ops.prompt.as_bytes(); + let input = input.iter().map(|x| *x as f32).collect::>(); + + let result = self.inference(&input); + + match result { + Ok(output) => { + let output = output + .iter() + .map(|f| f.to_string()) + .collect::>() + .join(","); + Ok(output) + } + Err(e) => Err(e), + } + } +} diff --git a/backend/rust/models/src/onnx/inference.rs b/backend/rust/models/src/onnx/inference.rs deleted file mode 100644 index febee9cf4b0b..000000000000 --- a/backend/rust/models/src/onnx/inference.rs +++ /dev/null @@ -1 +0,0 @@ -use std::env::args; diff --git a/backend/rust/models/src/onnx/mod.rs b/backend/rust/models/src/onnx/mod.rs deleted file mode 100644 index 6cb6eb000455..000000000000 --- a/backend/rust/models/src/onnx/mod.rs +++ /dev/null @@ -1,90 +0,0 @@ -//! Defination of a mninst model and config of it. -//! The source code is from https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/src/model.rs -//! The license is Apache-2.0 and MIT. -//! Adapter by Aisuko - -pub(crate) mod inference; -use inference::*; - -use burn::{ - module::Module, - nn::{self, BatchNorm, PaddingConfig2d}, - tensor::{backend::Backend, Tensor}, -}; - -const NUM_CLASSES: usize = 10; - -#[derive(Module, Debug)] -/// A struct representing an ONNX model. -pub struct Model { - /// The first convolutional block of the model. - conv1: ConvBlock, - /// The second convolutional block of the model. - conv2: ConvBlock, - /// The third convolutional block of the model. - conv3: ConvBlock, - /// A dropout layer used in the model. - dropout: nn::Dropout, - /// The first fully connected layer of the model. - fc1: nn::Linear, - /// The second fully connected layer of the model. - fc2: nn::Linear, - /// The activation function used in the model. - activation: nn::GELU, -} - -impl Model { - pub fn new() -> Self { - todo!("Implement the Model::new() function") - } - - pub fn forward(&self, input: Tensor) -> Tensor { - todo!("Implement the Model::forward() function") - } -} - -/// A struct representing a convolutional block in a neural network model. -#[derive(Module, Debug)] -pub struct ConvBlock { - /// A 2D convolutional layer. - conv: nn::conv::Conv2d, - /// A batch normalization layer. 
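The impl LLM for MNINST above treats every byte of PredictOptions.prompt as one pixel intensity and joins the ten class probabilities into a comma-separated string. A hedged client-side sketch of both conversions (helper names are illustrative and assume the prompt carries exactly 28*28 bytes; nothing here is part of the patch):

// Build a 784-value pixel vector from a prompt, and decode a reply string back
// into a predicted digit, mirroring the conversions in `predict` above.
fn prompt_to_pixels(prompt: &str) -> Vec<f32> {
    prompt.as_bytes().iter().map(|b| *b as f32).collect()
}

fn argmax(probs: &[f32]) -> usize {
    probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap_or(0)
}

fn main() {
    let prompt = "0".repeat(28 * 28); // 784 bytes, each read as a pixel value
    assert_eq!(prompt_to_pixels(&prompt).len(), 784);

    // A reply like "0.0,...,0.9,..." decodes back into the most likely digit.
    let reply = "0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.9,0.05,0.05";
    let probs: Vec<f32> = reply.split(',').map(|s| s.parse().unwrap()).collect();
    assert_eq!(argmax(&probs), 7);
}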
- norm: BatchNorm, - /// A GELU activation function. - activation: nn::GELU, -} - -/// A convolutional block with batch normalization and GELU activation. -impl ConvBlock { - /// Creates a new `ConvBlock` with the given number of output channels and kernel size. - pub fn new(channels: [usize; 2], kernel_size: [usize; 2]) -> Self { - // Initialize a 2D convolutional layer with the given output channels and kernel size, - // and set the padding to "valid". - let conv = nn::conv::Conv2dConfig::new(channels, kernel_size) - .with_padding(PaddingConfig2d::Valid) - .init(); - - // Initialize a batch normalization layer with the number of channels in the second dimension of the output. - let norm = nn::BatchNormConfig::new(channels[1]).init(); - - // Create a new `ConvBlock` with the initialized convolutional and batch normalization layers, - // and a GELU activation function. - Self { - conv: conv, - norm: norm, - activation: nn::GELU::new(), - } - } - - /// Applies the convolutional block to the given input tensor. - pub fn forward(&self, input: Tensor) -> Tensor { - // Apply the convolutional layer to the input tensor. - let x = self.conv.forward(input); - - // Apply the batch normalization layer to the output of the convolutional layer. - let x = self.norm.forward(x); - - // Apply the GELU activation function to the output of the batch normalization layer. - self.activation.forward(x) - } -} From 660cc49746df7d8818a015f09d7fb949ff2b085f Mon Sep 17 00:00:00 2001 From: Aisuko Date: Sat, 4 Nov 2023 13:11:22 +1100 Subject: [PATCH 15/18] Add check memory feature Signed-off-by: Aisuko --- backend/rust/Makefile | 9 ++++- backend/rust/backend-burn/src/main.rs | 54 +++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/backend/rust/Makefile b/backend/rust/Makefile index ad550d67e9b9..4213797500f8 100644 --- a/backend/rust/Makefile +++ b/backend/rust/Makefile @@ -35,6 +35,7 @@ burn: @cargo run --bin server --package backend-burn + ############################################################################################################ # gRPC testing commands @@ -47,4 +48,10 @@ list: .PHONY: health health: @echo "Burning..." - grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto -d '{}' '[::1]:50051' backend.Backend/Health + @grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto -d '{}' '[::1]:50051' backend.Backend/Health + + +.PHONY: status +status: + @echo "Burning..." 
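The status target added here calls the Status RPC whose implementation follows in this patch. One caveat about the RSS figure that implementation reads (an observation, not something the patch states): the 24th whitespace-separated field of /proc/<pid>/stat is resident set size in pages, while ps -o rss= on macOS reports kilobytes, so the two platforms return different units unless converted. A minimal sketch of such a conversion, assuming a 4096-byte page size (a real implementation would query sysconf(_SC_PAGESIZE) instead of hard-coding it):

// Hypothetical helpers to put both platforms' RSS readings into bytes.
fn linux_rss_pages_to_bytes(pages: u64) -> u64 {
    pages * 4096
}

fn macos_rss_kb_to_bytes(kb: u64) -> u64 {
    kb * 1024
}

fn main() {
    assert_eq!(linux_rss_pages_to_bytes(10), 40_960);
    assert_eq!(macos_rss_kb_to_bytes(10), 10_240);
}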
+ @grpcurl -plaintext -import-path ../../pkg/grpc/proto -proto backend.proto -d '{}' '[::1]:50051' backend.Backend/Status diff --git a/backend/rust/backend-burn/src/main.rs b/backend/rust/backend-burn/src/main.rs index 9ee1d96cf4c2..6aadfaba69e6 100644 --- a/backend/rust/backend-burn/src/main.rs +++ b/backend/rust/backend-burn/src/main.rs @@ -1,9 +1,11 @@ +use std::collections::HashMap; use std::net::SocketAddr; use bunker::pb::Result as PbResult; use bunker::pb::{ - EmbeddingResult, GenerateImageRequest, HealthMessage, ModelOptions, PredictOptions, Reply, - StatusResponse, TokenizationResponse, TranscriptRequest, TranscriptResult, TtsRequest, + EmbeddingResult, GenerateImageRequest, HealthMessage, MemoryUsageData, ModelOptions, + PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, + TranscriptResult, TtsRequest, }; use bunker::BackendService; @@ -13,6 +15,10 @@ use tonic::{Request, Response, Status}; use async_trait::async_trait; use tracing::{event, span, Level}; +use tracing_subscriber::filter::LevelParseError; + +use std::fs; +use std::process::{Command,id}; use models::*; // implement BackendService trait in bunker @@ -115,7 +121,49 @@ impl BackendService for BurnBackend { &self, request: Request, ) -> Result, Status> { - todo!() + + // Here we do not need to cover the windows platform + let mut breakdown = HashMap::new(); + let mut memory_usage: u64=0; + + #[cfg(target_os = "linux")] + { + let pid =id(); + let stat = fs::read_to_string(format!("/proc/{}/stat", pid)).expect("Failed to read stat file"); + + let stats: Vec<&str> = stat.split_whitespace().collect(); + memory_usage = stats[23].parse::().expect("Failed to parse RSS"); + } + + #[cfg(target_os="macos")] + { + let output=Command::new("ps") + .arg("-p") + .arg(id().to_string()) + .arg("-o") + .arg("rss=") + .output() + .expect("failed to execute process"); + + memory_usage = String::from_utf8_lossy(&output.stdout) + .trim() + .parse::() + .expect("Failed to parse memory usage"); + + } + breakdown.insert("RSS".to_string(), memory_usage); + + let memory_usage = Option::from(MemoryUsageData { + total: memory_usage, + breakdown, + }); + + let reponse = StatusResponse { + state: 0, //TODO: add state https://github.com/mudler/LocalAI/blob/9b17af18b3aa0c3cab16284df2d6f691736c30c1/pkg/grpc/proto/backend.proto#L188C9-L188C9 + memory: memory_usage, + }; + + Ok(Response::new(reponse)) } } From d62c701403a2125c3ddde42e033f1e6051a24f9e Mon Sep 17 00:00:00 2001 From: Aisuko Date: Sat, 18 Nov 2023 12:21:46 +1100 Subject: [PATCH 16/18] Trying to call mnist model in main Signed-off-by: Aisuko --- backend/rust/Cargo.toml | 2 +- backend/rust/backend-burn/src/main.rs | 199 ----------- .../rust/{backend-burn => backend}/Cargo.toml | 2 +- .../rust/{backend-burn => backend}/Makefile | 0 backend/rust/backend/src/main.rs | 309 ++++++++++++++++++ backend/rust/models/src/lib.rs | 11 +- backend/rust/models/src/mnist/mnist.rs | 10 +- backend/rust/models/src/mnist/mod.rs | 12 +- 8 files changed, 334 insertions(+), 211 deletions(-) delete mode 100644 backend/rust/backend-burn/src/main.rs rename backend/rust/{backend-burn => backend}/Cargo.toml (94%) rename backend/rust/{backend-burn => backend}/Makefile (100%) create mode 100644 backend/rust/backend/src/main.rs diff --git a/backend/rust/Cargo.toml b/backend/rust/Cargo.toml index 60719544e525..c9ec91b1559f 100644 --- a/backend/rust/Cargo.toml +++ b/backend/rust/Cargo.toml @@ -2,7 +2,7 @@ resolver = "2" members = [ "bunker", - "backend-burn", + "backend", "codegen", "models", ] \ 
No newline at end of file diff --git a/backend/rust/backend-burn/src/main.rs b/backend/rust/backend-burn/src/main.rs deleted file mode 100644 index 6aadfaba69e6..000000000000 --- a/backend/rust/backend-burn/src/main.rs +++ /dev/null @@ -1,199 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; - -use bunker::pb::Result as PbResult; -use bunker::pb::{ - EmbeddingResult, GenerateImageRequest, HealthMessage, MemoryUsageData, ModelOptions, - PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, - TranscriptResult, TtsRequest, -}; - -use bunker::BackendService; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{Request, Response, Status}; - -use async_trait::async_trait; - -use tracing::{event, span, Level}; -use tracing_subscriber::filter::LevelParseError; - -use std::fs; -use std::process::{Command,id}; - -use models::*; -// implement BackendService trait in bunker - -#[derive(Default, Debug)] -struct BurnBackend; - -#[async_trait] -impl BackendService for BurnBackend { - type PredictStreamStream = ReceiverStream>; - - #[tracing::instrument] - async fn health(&self, request: Request) -> Result, Status> { - // return a Result,Status> - let reply = Reply { - message: "OK".into(), - }; - let res = Response::new(reply); - Ok(res) - } - - #[tracing::instrument] - async fn predict(&self, request: Request) -> Result, Status> { - let mut models: Vec> = vec![Box::new(models::MNINST::new())]; - let result = models[0].predict(request.into_inner()); - - match result { - Ok(res) => { - let reply = Reply { - message: res.into(), - }; - let res = Response::new(reply); - Ok(res) - } - Err(e) => { - let reply = Reply { - message: e.to_string().into(), - }; - let res = Response::new(reply); - Ok(res) - } - } - } - - #[tracing::instrument] - async fn load_model( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn predict_stream( - &self, - request: Request, - ) -> Result>>, Status> { - todo!() - } - - #[tracing::instrument] - async fn embedding( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn generate_image( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn audio_transcription( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn tts(&self, request: Request) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn tokenize_string( - &self, - request: Request, - ) -> Result, Status> { - todo!() - } - - #[tracing::instrument] - async fn status( - &self, - request: Request, - ) -> Result, Status> { - - // Here we do not need to cover the windows platform - let mut breakdown = HashMap::new(); - let mut memory_usage: u64=0; - - #[cfg(target_os = "linux")] - { - let pid =id(); - let stat = fs::read_to_string(format!("/proc/{}/stat", pid)).expect("Failed to read stat file"); - - let stats: Vec<&str> = stat.split_whitespace().collect(); - memory_usage = stats[23].parse::().expect("Failed to parse RSS"); - } - - #[cfg(target_os="macos")] - { - let output=Command::new("ps") - .arg("-p") - .arg(id().to_string()) - .arg("-o") - .arg("rss=") - .output() - .expect("failed to execute process"); - - memory_usage = String::from_utf8_lossy(&output.stdout) - .trim() - .parse::() - .expect("Failed to parse memory usage"); - - } - breakdown.insert("RSS".to_string(), memory_usage); - - let memory_usage = Option::from(MemoryUsageData { - total: 
memory_usage, - breakdown, - }); - - let reponse = StatusResponse { - state: 0, //TODO: add state https://github.com/mudler/LocalAI/blob/9b17af18b3aa0c3cab16284df2d6f691736c30c1/pkg/grpc/proto/backend.proto#L188C9-L188C9 - memory: memory_usage, - }; - - Ok(Response::new(reponse)) - } -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - let subscriber = tracing_subscriber::fmt() - .compact() - .with_file(true) - .with_line_number(true) - .with_thread_ids(true) - .with_target(false) - .finish(); - - tracing::subscriber::set_global_default(subscriber)?; - - // call bunker::run with BurnBackend - let burn_backend = BurnBackend {}; - let addr = "[::1]:50051" - .parse::() - .expect("Failed to parse address"); - - // Implmenet Into for addr - let result = bunker::run(burn_backend, addr).await?; - - event!(Level::INFO, "Burn Server is starting"); - - let span = span!(Level::INFO, "Burn Server"); - let _enter = span.enter(); - - event!(Level::INFO, "Burn Server started successfully"); - - Ok(result) -} diff --git a/backend/rust/backend-burn/Cargo.toml b/backend/rust/backend/Cargo.toml similarity index 94% rename from backend/rust/backend-burn/Cargo.toml rename to backend/rust/backend/Cargo.toml index f97347d324b8..4f32f6e1640a 100644 --- a/backend/rust/backend-burn/Cargo.toml +++ b/backend/rust/backend/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "backend-burn" +name = "backend" version = "0.1.0" edition = "2021" diff --git a/backend/rust/backend-burn/Makefile b/backend/rust/backend/Makefile similarity index 100% rename from backend/rust/backend-burn/Makefile rename to backend/rust/backend/Makefile diff --git a/backend/rust/backend/src/main.rs b/backend/rust/backend/src/main.rs new file mode 100644 index 000000000000..7f6b442c3b6c --- /dev/null +++ b/backend/rust/backend/src/main.rs @@ -0,0 +1,309 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::process::{id, Command}; +use std::sync::{Arc, Mutex}; + +use bunker::pb::Result as PbResult; +use bunker::pb::{ + EmbeddingResult, GenerateImageRequest, HealthMessage, MemoryUsageData, ModelOptions, + PredictOptions, Reply, StatusResponse, TokenizationResponse, TranscriptRequest, + TranscriptResult, TtsRequest, +}; + +use bunker::BackendService; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; + +use async_trait::async_trait; + +use tracing::{event, span, Level}; + +use models::*; +// implement BackendService trait in bunker + +#[derive(Default, Debug)] +pub struct BurnBackend; + +#[async_trait] +impl BackendService for BurnBackend { + type PredictStreamStream = ReceiverStream>; + + #[tracing::instrument] + async fn health(&self, request: Request) -> Result, Status> { + // return a Result,Status> + let reply = Reply { + message: "OK".into(), + }; + let res = Response::new(reply); + Ok(res) + } + + #[tracing::instrument] + async fn predict(&self, request: Request) -> Result, Status> { + // TODO: How to get model from load_model function? 
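        // Note (not part of this patch): one way to resolve this TODO would be to keep the
        // model produced by load_model in a shared, thread-safe field on BurnBackend
        // (for example Mutex<Option<MNINST<Backend>>>) and borrow it here instead of
        // rebuilding the model on every request; a hedged sketch of that approach follows
        // the load_model hunk below.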
+ let mut model= MNINST::new("model.bin"); + let result = model.predict(request.get_ref().clone()); + match result { + Ok(output) => { + let reply = Reply { + message: output.into_bytes(), + }; + let res = Response::new(reply); + Ok(res) + } + Err(e) => { + let result = PbResult { + message: format!("Failed to predict: {}", e), + success: false, + }; + Err(Status::internal(result.message)) + } + } + } + + #[tracing::instrument] + async fn load_model( + &self, + request: Request, + ) -> Result, Status> { + let result= match request.get_ref().model.as_str() { + "mnist" => { + let mut model = MNINST::new("model.bin"); + let result = model.load_model(request.get_ref().clone()); + match result { + Ok(_) => { + let model = Arc::new(Mutex::new(model)); + let model = model.clone(); + let result = PbResult { + message: "Model loaded successfully".into(), + success: true, + }; + Ok(Response::new(result)) + } + Err(e) => { + let result = PbResult { + message: format!("Failed to load model: {}", e), + success: false, + }; + Err(Status::internal(result.message)) + } + } + } + _ => { + let result = PbResult { + message: format!("Model {} not found", request.get_ref().model), + success: false, + }; + Err(Status::internal(result.message)) + } + }; + // TODO: add model to backend, how to transfer model to backend and let predict funciton can use it? + result + } + + #[tracing::instrument] + async fn predict_stream( + &self, + request: Request, + ) -> Result>>, Status> { + todo!() + } + + #[tracing::instrument] + async fn embedding( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn generate_image( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn audio_transcription( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn tts(&self, request: Request) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn tokenize_string( + &self, + request: Request, + ) -> Result, Status> { + todo!() + } + + #[tracing::instrument] + async fn status( + &self, + request: Request, + ) -> Result, Status> { + // Here we do not need to cover the windows platform + let mut breakdown = HashMap::new(); + let mut memory_usage: u64 = 0; + + #[cfg(target_os = "linux")] + { + let pid = id(); + let stat = fs::read_to_string(format!("/proc/{}/stat", pid)) + .expect("Failed to read stat file"); + + let stats: Vec<&str> = stat.split_whitespace().collect(); + memory_usage = stats[23].parse::().expect("Failed to parse RSS"); + } + + #[cfg(target_os = "macos")] + { + let output = Command::new("ps") + .arg("-p") + .arg(id().to_string()) + .arg("-o") + .arg("rss=") + .output() + .expect("failed to execute process"); + + memory_usage = String::from_utf8_lossy(&output.stdout) + .trim() + .parse::() + .expect("Failed to parse memory usage"); + } + breakdown.insert("RSS".to_string(), memory_usage); + + let memory_usage = Option::from(MemoryUsageData { + total: memory_usage, + breakdown, + }); + + let reponse = StatusResponse { + state: 0, //TODO: add state https://github.com/mudler/LocalAI/blob/9b17af18b3aa0c3cab16284df2d6f691736c30c1/pkg/grpc/proto/backend.proto#L188C9-L188C9 + memory: memory_usage, + }; + + Ok(Response::new(reponse)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tonic::Request; + + #[tokio::test] + async fn test_health() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = 
backend.health(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = String::from_utf8(response.get_ref().message.clone()).unwrap(); + assert_eq!(message_str, "OK"); + } + #[tokio::test] + async fn test_status() { + let backend = BurnBackend::default(); + let request = Request::new(HealthMessage {}); + let response = backend.status(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let state = response.get_ref().state; + assert_eq!(state, 0); + } + + #[tokio::test] + async fn test_load_model() { + let backend = BurnBackend::default(); + let request = Request::new(ModelOptions { + model: "test".to_string(), + context_size: 0, + seed: 0, + n_batch: 0, + f16_memory: false, + m_lock: false, + m_map: false, + vocab_only: false, + low_vram: false, + embeddings: false, + numa: false, + ngpu_layers: 0, + main_gpu: "".to_string(), + tensor_split: "".to_string(), + threads: 1, + library_search_path: "".to_string(), + rope_freq_base: 0.0, + rope_freq_scale: 0.0, + rms_norm_eps: 0.0, + ngqa: 0, + model_file: "".to_string(), + device: "".to_string(), + use_triton: false, + model_base_name: "".to_string(), + use_fast_tokenizer: false, + pipeline_type: "".to_string(), + scheduler_type: "".to_string(), + cuda: false, + cfg_scale: 0.0, + img2img: false, + clip_model: "".to_string(), + clip_subfolder: "".to_string(), + clip_skip: 0, + tokenizer: "".to_string(), + lora_base: "".to_string(), + lora_adapter: "".to_string(), + no_mul_mat_q: false, + draft_model: "".to_string(), + audio_path: "".to_string(), + quantization: "".to_string(), + }); + let response = backend.load_model(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + //TO_DO: add test for response + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = tracing_subscriber::fmt() + .compact() + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_target(false) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + + // call bunker::run with BurnBackend + let burn_backend = BurnBackend {}; + let addr = "[::1]:50051" + .parse::() + .expect("Failed to parse address"); + + // Implmenet Into for addr + let result = bunker::run(burn_backend, addr).await?; + + event!(Level::INFO, "Burn Server is starting"); + + let span = span!(Level::INFO, "Burn Server"); + let _enter = span.enter(); + + event!(Level::INFO, "Burn Server started successfully"); + + Ok(result) +} diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs index f3302e83ef73..b739bf0b0d51 100644 --- a/backend/rust/models/src/lib.rs +++ b/backend/rust/models/src/lib.rs @@ -1,9 +1,16 @@ +use bunker::pb::{ModelOptions, PredictOptions}; + pub(crate) mod mnist; pub use mnist::mnist::MNINST; -use bunker::pb::{ModelOptions, PredictOptions}; - +/// Trait for implementing a Language Model. pub trait LLM { + /// Loads the model from the given options. fn load_model(&mut self, request: ModelOptions) -> Result>; + /// Predicts the output for the given input options. fn predict(&mut self, request: PredictOptions) -> Result>; } + +pub struct LLModel { + model: Box, +} diff --git a/backend/rust/models/src/mnist/mnist.rs b/backend/rust/models/src/mnist/mnist.rs index 995b2706ed05..7a727bbbf441 100644 --- a/backend/rust/models/src/mnist/mnist.rs +++ b/backend/rust/models/src/mnist/mnist.rs @@ -4,7 +4,7 @@ //! 
Adapter by Aisuko use burn::{ - backend::wgpu::{compute::init_async, AutoGraphicsApi, WgpuDevice}, + backend::wgpu::{AutoGraphicsApi, WgpuDevice}, module::Module, nn::{self, BatchNorm, PaddingConfig2d}, record::{BinBytesRecorder, FullPrecisionSettings, Recorder}, @@ -12,7 +12,6 @@ use burn::{ }; // https://github.com/burn-rs/burn/blob/main/examples/mnist-inference-web/model.bin -static STATE_ENCODED: &[u8] = include_bytes!("model.bin"); const NUM_CLASSES: usize = 10; @@ -36,7 +35,7 @@ pub struct MNINST { } impl MNINST { - pub fn new() -> Self { + pub fn new(model_name: &str) -> Self { let conv1 = ConvBlock::new([1, 8], [3, 3]); // 1 input channel, 8 output channels, 3x3 kernel size let conv2 = ConvBlock::new([8, 16], [3, 3]); // 8 input channels, 16 output channels, 3x3 kernel size let conv3 = ConvBlock::new([16, 24], [3, 3]); // 16 input channels, 24 output channels, 3x3 kernel size @@ -59,8 +58,9 @@ impl MNINST { fc2: fc2, activation: nn::GELU::new(), }; + let state_encoded: &[u8] = &std::fs::read(model_name).expect("Failed to load model"); let record = BinBytesRecorder::::default() - .load(STATE_ENCODED.to_vec()) + .load(state_encoded.to_vec()) .expect("Failed to decode state"); instance.load_record(record) @@ -178,7 +178,7 @@ mod tests { pub type Backend = burn::backend::NdArrayBackend; #[test] fn test_inference() { - let mut model = MNINST::::new(); + let mut model = MNINST::::new("model.bin"); let output = model.inference(&[0.0; 28 * 28]).unwrap(); assert_eq!(output.len(), 10); } diff --git a/backend/rust/models/src/mnist/mod.rs b/backend/rust/models/src/mnist/mod.rs index d53b76c6c7a4..8cc85ce0e520 100644 --- a/backend/rust/models/src/mnist/mod.rs +++ b/backend/rust/models/src/mnist/mod.rs @@ -1,14 +1,20 @@ use crate::LLM; -use bunker::pb::{ModelOptions, PredictOptions}; pub(crate) mod mnist; +use mnist::MNINST; + +use bunker::pb::{ModelOptions, PredictOptions}; + #[cfg(feature = "ndarray")] pub type Backend = burn::backend::NdArrayBackend; -impl LLM for mnist::MNINST { +impl LLM for MNINST { fn load_model(&mut self, request: ModelOptions) -> Result> { - todo!("load model") + let model = request.model_file; + let instance = MNINST::::new(&model); + *self = instance; + Ok("".to_string()) } fn predict(&mut self, pre_ops: PredictOptions) -> Result> { From b210203166cc9b5454de56d65b7f946024138c33 Mon Sep 17 00:00:00 2001 From: Aisuko Date: Sun, 19 Nov 2023 12:35:47 +1100 Subject: [PATCH 17/18] Add test case for load model and import getusage Signed-off-by: Aisuko --- backend/rust/backend/Cargo.toml | 3 +- backend/rust/backend/src/main.rs | 165 ++++++++++++++----------- backend/rust/models/src/lib.rs | 10 +- backend/rust/models/src/mnist/mnist.rs | 5 +- backend/rust/models/src/mnist/mod.rs | 29 ++--- 5 files changed, 114 insertions(+), 98 deletions(-) diff --git a/backend/rust/backend/Cargo.toml b/backend/rust/backend/Cargo.toml index 4f32f6e1640a..767c349c8574 100644 --- a/backend/rust/backend/Cargo.toml +++ b/backend/rust/backend/Cargo.toml @@ -21,4 +21,5 @@ tonic = "0.10.2" tokio-stream = "0.1.14" tracing = "0.1" -tracing-subscriber = "0.3" \ No newline at end of file +tracing-subscriber = "0.3" +nix = { version="0.27.1", features=["resource"]} diff --git a/backend/rust/backend/src/main.rs b/backend/rust/backend/src/main.rs index 7f6b442c3b6c..0b60212d5790 100644 --- a/backend/rust/backend/src/main.rs +++ b/backend/rust/backend/src/main.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::process::{id, Command}; use std::sync::{Arc, Mutex}; use 
bunker::pb::Result as PbResult; @@ -19,8 +18,10 @@ use async_trait::async_trait; use tracing::{event, span, Level}; use models::*; -// implement BackendService trait in bunker + +/// TODO: In order to use the model, we need to add some common attributes like: model, device, tokenizer, embeddings, etc. +/// And these attributes should be thread safe. #[derive(Default, Debug)] pub struct BurnBackend; @@ -41,24 +42,7 @@ impl BackendService for BurnBackend { #[tracing::instrument] async fn predict(&self, request: Request) -> Result, Status> { // TODO: How to get model from load_model function? - let mut model= MNINST::new("model.bin"); - let result = model.predict(request.get_ref().clone()); - match result { - Ok(output) => { - let reply = Reply { - message: output.into_bytes(), - }; - let res = Response::new(reply); - Ok(res) - } - Err(e) => { - let result = PbResult { - message: format!("Failed to predict: {}", e), - success: false, - }; - Err(Status::internal(result.message)) - } - } + todo!("How to get model from load_model function?") } #[tracing::instrument] @@ -66,35 +50,21 @@ impl BackendService for BurnBackend { &self, request: Request, ) -> Result, Status> { - let result= match request.get_ref().model.as_str() { + let result = match request.get_ref().model.as_str() { "mnist" => { - let mut model = MNINST::new("model.bin"); - let result = model.load_model(request.get_ref().clone()); - match result { - Ok(_) => { - let model = Arc::new(Mutex::new(model)); - let model = model.clone(); - let result = PbResult { - message: "Model loaded successfully".into(), - success: true, - }; - Ok(Response::new(result)) - } - Err(e) => { - let result = PbResult { - message: format!("Failed to load model: {}", e), - success: false, - }; - Err(Status::internal(result.message)) - } - } + let model = MNINST::load_model(request.get_ref().clone()); + let result= PbResult { + message: format!("Model {} loaded successfully", request.get_ref().model), + success: true, + }; + Ok(Response::new(result)) } _ => { let result = PbResult { message: format!("Model {} not found", request.get_ref().model), success: false, }; - Err(Status::internal(result.message)) + Ok(Response::new(result)) } }; // TODO: add model to backend, how to transfer model to backend and let predict funciton can use it? 
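A minimal sketch (an assumption, not part of this patch) of one way to resolve the TODO above: keep the model built in load_model behind a thread-safe field so that predict can reuse it instead of rebuilding it per request. The struct name `SharedBurnBackend`, the field `model`, and the backend alias `NdBackend` are illustrative only; interior mutability keeps the `&self` signatures of the BackendService methods unchanged.

    use std::sync::Mutex;

    use models::MNINST;

    // Assumed concrete backend, mirroring the NdArray backend used by the models crate.
    type NdBackend = burn::backend::NdArrayBackend<f32>;

    #[derive(Default, Debug)]
    pub struct SharedBurnBackend {
        // None until load_model has succeeded.
        model: Mutex<Option<MNINST<NdBackend>>>,
    }

    impl SharedBurnBackend {
        // Called at the end of a successful load_model.
        fn set_model(&self, m: MNINST<NdBackend>) {
            *self.model.lock().expect("model lock poisoned") = Some(m);
        }

        // Called from predict: runs `f` against the loaded model, if any.
        fn with_model<R>(&self, f: impl FnOnce(&mut MNINST<NdBackend>) -> R) -> Option<R> {
            let mut guard = self.model.lock().expect("model lock poisoned");
            guard.as_mut().map(f)
        }
    }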
@@ -155,31 +125,9 @@ impl BackendService for BurnBackend { let mut breakdown = HashMap::new(); let mut memory_usage: u64 = 0; - #[cfg(target_os = "linux")] - { - let pid = id(); - let stat = fs::read_to_string(format!("/proc/{}/stat", pid)) - .expect("Failed to read stat file"); - - let stats: Vec<&str> = stat.split_whitespace().collect(); - memory_usage = stats[23].parse::().expect("Failed to parse RSS"); - } - - #[cfg(target_os = "macos")] - { - let output = Command::new("ps") - .arg("-p") - .arg(id().to_string()) - .arg("-o") - .arg("rss=") - .output() - .expect("failed to execute process"); - - memory_usage = String::from_utf8_lossy(&output.stdout) - .trim() - .parse::() - .expect("Failed to parse memory usage"); - } + use nix::sys::resource::{getrusage, UsageWho}; + let usage = getrusage(UsageWho::RUSAGE_SELF).expect("Failed to fet usage"); + memory_usage = usage.as_ref().ru_maxrss as u64; breakdown.insert("RSS".to_string(), memory_usage); let memory_usage = Option::from(MemoryUsageData { @@ -221,13 +169,15 @@ mod tests { assert!(response.is_ok()); let response = response.unwrap(); let state = response.get_ref().state; + let memory = response.get_ref().memory.clone(); assert_eq!(state, 0); + assert!(memory.is_some()); } #[tokio::test] async fn test_load_model() { let backend = BurnBackend::default(); - let request = Request::new(ModelOptions { + let model_options = ModelOptions { model: "test".to_string(), context_size: 0, seed: 0, @@ -248,7 +198,7 @@ mod tests { rope_freq_scale: 0.0, rms_norm_eps: 0.0, ngqa: 0, - model_file: "".to_string(), + model_file: "models/src/mnist/model.bin".to_string(), device: "".to_string(), use_triton: false, model_base_name: "".to_string(), @@ -268,9 +218,84 @@ mod tests { draft_model: "".to_string(), audio_path: "".to_string(), quantization: "".to_string(), - }); + }; + + // Load the wrong model + let request = Request::new(model_options.clone()); + let response = backend.load_model(request).await; + + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = response.get_ref().message.clone(); + assert_eq!( + message_str, + format!("Model {} not found", model_options.model.clone()) + ); + + // Load the correct model + let mut model_options2=model_options.clone(); + model_options2.model="mnist".to_string(); + model_options2.model_file="models/src/mnist/model.bin".to_string(); + + let request = Request::new(model_options2.clone()); let response = backend.load_model(request).await; + assert!(response.is_ok()); + let response = response.unwrap(); + let message_str = response.get_ref().message.clone(); + assert_eq!( + message_str, + format!("Model {} loaded successfully", model_options2.model.clone()) + ); + } + + #[tokio::test] + async fn test_predict() { + let backend = BurnBackend::default(); + let request = Request::new(PredictOptions { + prompt: "test".to_string(), + seed: 100, + threads: 1, + tokens: 10, + temperature: 0.0, + top_k: 0, + top_p: 0.0, + repeat: 0, + batch: 1, + n_keep: 0, + penalty: 0.0, + f16kv: false, + debug_mode: false, + stop_prompts: vec!["".to_string()], + ignore_eos: false, + tail_free_sampling_z: 0.0, + typical_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + mirostat: 0, + mirostat_eta: 0.0, + mirostat_tau: 0.0, + penalize_nl: false, + logit_bias: "".to_string(), + m_lock: false, + m_map: false, + prompt_cache_all: false, + prompt_cache_ro: false, + grammar: "".to_string(), + main_gpu: "".to_string(), + tensor_split: "".to_string(), + prompt_cache_path: "".to_string(), + debug: false, + 
embedding_tokens: vec![0], + embeddings: "".to_string(), + rope_freq_base: 0.0, + rope_freq_scale: 0.0, + negative_prompt_scale: 0.0, + negative_prompt: "".to_string(), + n_draft: 0, + }); + let response: Result, Status> = backend.predict(request).await; + assert!(response.is_ok()); let response = response.unwrap(); //TO_DO: add test for response diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs index b739bf0b0d51..c53edaf54692 100644 --- a/backend/rust/models/src/lib.rs +++ b/backend/rust/models/src/lib.rs @@ -5,12 +5,12 @@ pub use mnist::mnist::MNINST; /// Trait for implementing a Language Model. pub trait LLM { + + type Model: LLM; + /// Loads the model from the given options. - fn load_model(&mut self, request: ModelOptions) -> Result>; + fn load_model(request: ModelOptions) -> Result>; /// Predicts the output for the given input options. - fn predict(&mut self, request: PredictOptions) -> Result>; + fn predict(request: PredictOptions) -> Result>; } -pub struct LLModel { - model: Box, -} diff --git a/backend/rust/models/src/mnist/mnist.rs b/backend/rust/models/src/mnist/mnist.rs index 7a727bbbf441..6b2c06ae2932 100644 --- a/backend/rust/models/src/mnist/mnist.rs +++ b/backend/rust/models/src/mnist/mnist.rs @@ -58,7 +58,10 @@ impl MNINST { fc2: fc2, activation: nn::GELU::new(), }; - let state_encoded: &[u8] = &std::fs::read(model_name).expect("Failed to load model"); + use std::path::Path; + let path_name=Path::new(model_name); + + let state_encoded: &[u8] = &std::fs::read(&path_name).expect("Failed to read model file"); let record = BinBytesRecorder::::default() .load(state_encoded.to_vec()) .expect("Failed to decode state"); diff --git a/backend/rust/models/src/mnist/mod.rs b/backend/rust/models/src/mnist/mod.rs index 8cc85ce0e520..8372ec3f1bbd 100644 --- a/backend/rust/models/src/mnist/mod.rs +++ b/backend/rust/models/src/mnist/mod.rs @@ -10,30 +10,17 @@ use bunker::pb::{ModelOptions, PredictOptions}; pub type Backend = burn::backend::NdArrayBackend; impl LLM for MNINST { - fn load_model(&mut self, request: ModelOptions) -> Result> { + type Model = MNINST; + + fn load_model(request: ModelOptions) -> Result> { let model = request.model_file; - let instance = MNINST::::new(&model); - *self = instance; - Ok("".to_string()) + let instance= MNINST::::new(&model); + // check instance and return result + Ok(instance) } - fn predict(&mut self, pre_ops: PredictOptions) -> Result> { + fn predict(pre_ops: PredictOptions) -> Result> { // convert prost::alloc::string::String to &[f32] - let input = pre_ops.prompt.as_bytes(); - let input = input.iter().map(|x| *x as f32).collect::>(); - - let result = self.inference(&input); - - match result { - Ok(output) => { - let output = output - .iter() - .map(|f| f.to_string()) - .collect::>() - .join(","); - Ok(output) - } - Err(e) => Err(e), - } + todo!() } } From c9901126900b210fcb73c3d6505d1544405c236f Mon Sep 17 00:00:00 2001 From: Aisuko Date: Thu, 23 Nov 2023 12:15:21 +1100 Subject: [PATCH 18/18] Add llama for test Signed-off-by: Aisuko --- backend/rust/models/src/lib.rs | 3 + backend/rust/models/src/llama/llama.rs | 720 +++++++++++++++++++++++++ backend/rust/models/src/llama/mod.rs | 3 + 3 files changed, 726 insertions(+) create mode 100644 backend/rust/models/src/llama/llama.rs create mode 100644 backend/rust/models/src/llama/mod.rs diff --git a/backend/rust/models/src/lib.rs b/backend/rust/models/src/lib.rs index c53edaf54692..8d13c577703f 100644 --- a/backend/rust/models/src/lib.rs +++ b/backend/rust/models/src/lib.rs 
@@ -3,6 +3,9 @@ use bunker::pb::{ModelOptions, PredictOptions}; pub(crate) mod mnist; pub use mnist::mnist::MNINST; +pub(crate) mod llama; + + /// Trait for implementing a Language Model. pub trait LLM { diff --git a/backend/rust/models/src/llama/llama.rs b/backend/rust/models/src/llama/llama.rs new file mode 100644 index 000000000000..805f3eb7c34a --- /dev/null +++ b/backend/rust/models/src/llama/llama.rs @@ -0,0 +1,720 @@ +//! The source code is from https://github.com/Gadersd/llama2-burn/blob/main/src/model.rs +//! The license is Special MIT License.(And the code will be replaced in the future, currently it is just a test.) +//! Adapter by Aisuko + +use std::f32::NEG_INFINITY; + +use burn::{ + config::Config, + module::{Module, Param}, + nn, + tensor::{ + activation::{sigmoid, softmax}, + backend::Backend, + module::embedding, + Data, Distribution, Int, Tensor, + }, backend::wgpu::tensor, +}; + +#[derive(Config, Debug)] +pub struct LlamaConfig { + n_vocab: usize, + n_ctx: usize, + n_state: usize, + multiple_of: usize, + ffn_dim_multiplier: Option, + n_head: usize, + n_kv_head: usize, + n_layer: usize, + #[config(default = 1e-6)] + norm_eps: f64, +} + +impl LlamaConfig { + pub fn init(&self) -> Llama { + let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init(); + let rotary_encoder = + RotaryEncodingConfig::new(self.n_ctx, self.n_state / self.n_head, 10000.0).init(); + let blocks: Vec<_> = (0..self.n_layer) + .into_iter() + .map(|_| { + ResidualDecoderAttentionBlockConfig::new( + self.n_state, + self.multiple_of, + self.n_head, + self.n_kv_head, + self.norm_eps, + ) + .with_ffn_dim_multiplier(self.ffn_dim_multiplier) + .init() + }) + .collect(); + + let norm = RMSNormConfig::new(self.n_state, self.norm_eps).init(); + let output = nn::LinearConfig::new(self.n_state, self.n_vocab) + .with_bias(false) + .init(); + + let mask = attn_decoder_mask(self.n_ctx).into(); + + let n_vocab = self.n_vocab; + let n_ctx = self.n_ctx; + + Llama { + token_embedding, + rotary_encoder, + blocks, + norm, + output, + mask, + n_vocab, + n_ctx, + } + } +} + +#[derive(Module, Debug)] +pub struct Llama { + token_embedding: nn::Embedding, + rotary_encoder: RotaryEncoding, + blocks: Vec>, + norm: RMSNorm, + output: nn::Linear, + mask: Param>, + n_vocab: usize, + n_ctx: usize, +} + +impl Llama { + pub fn forward(&self, x: Tensor) -> Tensor { + let [n_batch, seq_len] = x.dims(); + + assert!( + seq_len <= self.n_ctx, + "Token sequence length {} must not exceed {}.", + seq_len, + self.n_ctx + ); + + let x = self.token_embedding.forward(x); + + let mut x = x; + for block in &self.blocks { + x = block.forward(x, &self.rotary_encoder, self.mask.val()); + } + + self.output.forward(self.norm.forward(x)) + } +} + +#[derive(Config)] +pub struct ResidualDecoderAttentionBlockConfig { + n_state: usize, + multiple_of: usize, + ffn_dim_multiplier: Option, + n_head: usize, + n_kv_head: usize, + norm_eps: f64, +} + +impl ResidualDecoderAttentionBlockConfig { + fn init(&self) -> ResidualDecoderAttentionBlock { + let attn = + MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head, self.n_kv_head).init(); + let attn_norm = RMSNormConfig::new(self.n_state, self.norm_eps).init(); + + let mlp = MLPConfig::new(self.n_state, 4 * self.n_state, self.multiple_of) + .with_ffn_dim_multiplier(self.ffn_dim_multiplier) + .init(); + let mlp_norm = RMSNormConfig::new(self.n_state, self.norm_eps).init(); + + ResidualDecoderAttentionBlock { + attn, + attn_norm, + mlp, + mlp_norm, + } + } +} + +#[derive(Module, Debug)] +pub 
struct ResidualDecoderAttentionBlock { + attn: MultiHeadSelfAttention, + attn_norm: RMSNorm, + mlp: MLP, + mlp_norm: RMSNorm, +} + +impl ResidualDecoderAttentionBlock { + fn forward( + &self, + x: Tensor, + rotary_encoder: &RotaryEncoding, + mask: Tensor, + ) -> Tensor { + let x = x.clone() + + self + .attn + .forward(self.attn_norm.forward(x), rotary_encoder, Some(mask)); + let x = x.clone() + self.mlp.forward(self.mlp_norm.forward(x)); + return x; + } +} + +#[derive(Config)] +pub struct MLPConfig { + n_state: usize, + n_state_hidden: usize, + multiple_of: usize, + ffn_dim_multiplier: Option, +} + +impl MLPConfig { + fn init(&self) -> MLP { + let mut hidden_dim = 2 * self.n_state_hidden / 3; + if let Some(ffn_dim_multiplier) = self.ffn_dim_multiplier { + hidden_dim = ffn_dim_multiplier * hidden_dim; + } + hidden_dim = self.multiple_of * ((hidden_dim + self.multiple_of - 1) / self.multiple_of); + + let w1 = nn::LinearConfig::new(self.n_state, hidden_dim) + .with_bias(false) + .init(); + let w2 = nn::LinearConfig::new(hidden_dim, self.n_state) + .with_bias(false) + .init(); + let w3 = nn::LinearConfig::new(self.n_state, hidden_dim) + .with_bias(false) + .init(); + + let silu = SILU::new(); + + MLP { w1, w2, w3, silu } + } +} + +#[derive(Module, Debug)] +pub struct MLP { + w1: nn::Linear, + w2: nn::Linear, + w3: nn::Linear, + silu: SILU, +} + +impl MLP { + fn forward(&self, x: Tensor) -> Tensor { + self.w2 + .forward(self.silu.forward(self.w1.forward(x.clone())) * self.w3.forward(x)) + } +} + +#[derive(Config)] +pub struct MultiHeadSelfAttentionConfig { + n_state: usize, + n_head: usize, + n_kv_head: usize, +} + +impl MultiHeadSelfAttentionConfig { + fn init(&self) -> MultiHeadSelfAttention { + assert!( + self.n_state % self.n_head == 0, + "State size {} must be a multiple of the number of heads {}", + self.n_state, + self.n_head + ); + assert!( + self.n_head % self.n_kv_head == 0, + "The number of query heads {} must be a multiple of the number of k/v heads {}", + self.n_head, + self.n_kv_head + ); + + let n_head_dim = self.n_state / self.n_head; + + let n_head = self.n_head; + let n_kv_head = self.n_kv_head; + let query = nn::LinearConfig::new(self.n_state, self.n_state) + .with_bias(false) + .init(); + let key = nn::LinearConfig::new(self.n_state, n_kv_head * n_head_dim) + .with_bias(false) + .init(); + let value = nn::LinearConfig::new(self.n_state, n_kv_head * n_head_dim) + .with_bias(false) + .init(); + let out = nn::LinearConfig::new(self.n_state, self.n_state) + .with_bias(false) + .init(); + + MultiHeadSelfAttention { + n_head, + n_kv_head, + query, + key, + value, + out, + } + } +} + +#[derive(Module, Debug)] +pub struct MultiHeadSelfAttention { + n_head: usize, + n_kv_head: usize, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + out: nn::Linear, +} + +impl MultiHeadSelfAttention { + fn forward( + &self, + x: Tensor, + rotary_encoder: &RotaryEncoding, + mask: Option>, + ) -> Tensor { + let q = self.query.forward(x.clone()); + let k = self.key.forward(x.clone()); + let v = self.value.forward(x); + + let wv = qkv_attention_rotary(q, k, v, mask, self.n_head, self.n_kv_head, rotary_encoder); + + return self.out.forward(wv); + } +} + +fn qkv_attention_rotary( + q: Tensor, + k: Tensor, + v: Tensor, + mask: Option>, + n_head: usize, + n_kv_head: usize, + rotary_encoder: &RotaryEncoding, +) -> Tensor { + let [n_batch, n_qctx, n_state] = q.dims(); + let [_, n_ctx, _] = k.dims(); + + let n_hstate = n_state / n_head; + let scale = (n_hstate as f64).powf(-0.25); // keeps the 
value weightings roughly normally distributed + + let q: Tensor = q.reshape([n_batch, n_qctx, n_head, n_hstate]); + // interleave kv heads to match the number of q heads + let n_repeat = n_head / n_kv_head; + let k = repeat_kv(k.reshape([n_batch, n_ctx, n_kv_head, n_hstate]), n_repeat); + let v = repeat_kv(v.reshape([n_batch, n_ctx, n_kv_head, n_hstate]), n_repeat); + + // the last two dims need to be (n_ctx, n_hstate) + let q = rotary_encoder.forward(q.swap_dims(1, 2)) * scale; + let k = rotary_encoder.forward(k.swap_dims(1, 2)) * scale; + let v = v.swap_dims(1, 2); + + // compute value weightings + let qk = q.matmul(k.transpose()); + + // apply mask + let qk = if let Some(mask) = mask { + qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>() + } else { + qk + }; + + // normalize value weightings + let w = softmax(qk, 3); + let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3); + + return o; +} + +/// For a tensor of size (n_batch, n_ctx, n_kv_head, n_hstate), repeats the head keys or values in an interleaving manner so that the number +/// of heads is effectively multiplied by n_repeat +fn repeat_kv(x: Tensor, n_repeat: usize) -> Tensor { + if n_repeat > 1 { + let [n_batch, n_ctx, n_kv_head, n_hstate] = x.dims(); + x.repeat(3, n_repeat) + .reshape([n_batch, n_ctx, n_kv_head * n_repeat, n_hstate]) + } else { + x + } +} + +/// Generates a strictly upper triangular matrix filled with -inf that when added to an attention weight matrix prevents +/// vectors from attending to other vectors further in the sequence, preventing future information from flowing into the past +pub fn attn_decoder_mask(seq_length: usize) -> Tensor { + let mut mask = Tensor::::zeros([seq_length, seq_length]); + + for i in 0..(seq_length - 1) { + let values = Tensor::::zeros([1, seq_length - (i + 1)]).add_scalar(NEG_INFINITY); + mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values); + } + + return mask; +} + +#[derive(Config, Debug)] +pub struct RotaryEncodingConfig { + max_sequence_length: usize, + state_size: usize, + theta: f64, +} + +impl RotaryEncodingConfig { + pub fn init(&self) -> RotaryEncoding { + assert!(self.state_size % 2 == 0, "Head dims must be even."); + assert!(self.theta > 0.0, "Theta must be positive."); + + let half_state_size = self.state_size / 2; + + let arange_m = Tensor::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]]).into(); + + let inv_freq = powto( + self.theta, + Tensor::arange(0..half_state_size).float() * (2.0 / self.state_size as f64), + ) + .powf(-1.0); + + let periods = Tensor::arange(0..self.max_sequence_length) + .float() + .unsqueeze::<2>() + .transpose() + .repeat(1, half_state_size) + * inv_freq.unsqueeze(); + + let p_cos = periods.clone().cos(); + let p_sin = periods.sin(); + let tensor=Tensor::cat(vec![p_cos, p_sin], 1); + + let tensor2=tensor.reshape([self.max_sequence_length,2,half_state_size]); + + let tensor3=tensor2.transpose(); + + let tensor41=tensor3.repeat(2, 2); + + let tensor5=tensor41.reshape([self.max_sequence_length,self.state_size,2]); + + let freq_cis=tensor5.into(); + + RotaryEncoding { arange_m, freq_cis } + } + +} + +fn powto(base: f64, x: Tensor) -> Tensor { + let logbase = base.ln(); + x.mul_scalar(logbase).exp() +} + +/// Conceptually, pairs the values of a vector (v0 v1 v2 ... vn) into complex numbers (c0 c1 c2 ... cn/2) +/// which are then rotated counter-clockwise by the angle seq_index / theta^(2*pair_index/n). 
+/// This encodes sequence positions in a way that is agnostic to the maximum sequence length +/// which potentially allows for arbitrarily long sequences without retraining. +#[derive(Module, Debug)] +pub struct RotaryEncoding { + arange_m: Param>, + freq_cis: Param>, +} + +impl RotaryEncoding { + /// Applies rotary positional encoding to a tensor of dimenions (..., seq_len, n_state) + fn forward(&self, x: Tensor) -> Tensor { + assert!(D >= 2); + let orig_shape = x.shape(); + let (n_ctx, n_state) = (orig_shape.dims[D - 2], orig_shape.dims[D - 1]); + let dummy_dim_size = orig_shape.num_elements() / (n_ctx * n_state); + + let out = x + .reshape([dummy_dim_size, n_ctx, n_state / 2, 2]) + .matmul(self.arange_m.val().unsqueeze()) + .reshape([dummy_dim_size, n_ctx, n_state, 2]) + * self.freq_cis.val().slice([0..n_ctx]).unsqueeze(); + + out.sum_dim(D - 1).reshape(orig_shape) + } +} + +#[derive(Config)] +pub struct RMSNormConfig { + layer_size: usize, + eps: f64, +} + +impl RMSNormConfig { + fn init(&self) -> RMSNorm { + assert!(self.eps > 0.0, "eps must be positive."); + + let weight = Tensor::ones([self.layer_size]); + let eps = self.eps; + + RMSNorm { weight, eps } + } +} + +#[derive(Module, Debug)] +pub struct RMSNorm { + weight: Tensor, + eps: f64, +} + +impl RMSNorm { + fn forward(&self, x: Tensor) -> Tensor { + let rms = (x.clone().powf(2.0).mean_dim(D - 1) + self.eps).sqrt(); + (x / rms) * self.weight.clone().unsqueeze() + } +} + +#[derive(Module, Clone, Debug)] +pub struct SILU {} + +impl SILU { + fn new() -> Self { + Self {} + } + + fn forward(&self, x: Tensor) -> Tensor { + x.clone() * sigmoid(x) + } +} + +use npy::{self, NpyData}; //TODO: NpyData is deprecated, use ndarray_npy instead, but before replace it, I want to make sure the project works well. +use num_traits::cast::ToPrimitive; // TODO: Same here. 
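// Note (an assumption, not part of this patch): when the deprecated `npy` crate is
// replaced, the `ndarray-npy` crate's `ReadNpyExt::read_npy` can load the same .npy
// dumps into an `ndarray` array, which can then be converted into burn `Data`/`Tensor`
// values much as `numpy_to_tensor` below does today.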
+use std::error::Error; +use std::io::Read; + +use burn::tensor::ElementConversion; + +fn numpy_to_tensor( + numpy_data: NpyData, + device: &B::Device, +) -> Tensor { + let mut v = numpy_data.to_vec(); + + let shape: Vec<_> = v[0..D].into_iter().map(|&v| v as usize).collect(); + let data: Vec = v[D..].into_iter().map(|e| e.elem()).collect(); + + Tensor::from_data_device(Data::new(data, shape.into()), device) +} + +fn load_tensor( + name: &str, + path: &str, + device: &B::Device, +) -> Result, Box> { + let tensor_path = format!("{}/{}.npy", path, name); + + let mut buf = vec![]; + std::fs::File::open(&tensor_path)?.read_to_end(&mut buf)?; + + let tensor_numpy: NpyData = NpyData::from_bytes(&buf)?; + + let tensor = numpy_to_tensor(tensor_numpy, device); + + println!("{}", tensor_path); + + Ok(tensor) +} + +fn load_f32(name: &str, path: &str, device: &B::Device) -> Result> { + load_tensor::(name, path, device).map(|t| t.into_scalar().to_f32().unwrap()) +} + +fn load_usize( + name: &str, + path: &str, + device: &B::Device, +) -> Result> { + load_tensor::(name, path, device).map(|t| t.into_scalar().to_usize().unwrap()) +} + +fn load_linear( + path: &str, + device: &B::Device, +) -> Result, Box> { + let weight = load_tensor::("weight", path, device)?; + let bias = load_tensor::("bias", path, device).ok(); + + let record = nn::LinearRecord { + weight: weight.into(), + bias: bias.map(|t| t.into()), + }; + + let linear: nn::Linear = nn::LinearConfig::new(3, 3).init_with(record); + Ok(linear) +} + +fn load_rmsnorm(path: &str, device: &B::Device) -> Result, Box> { + let weight = load_tensor::("weight", path, device)?; + let eps = load_f32::("eps", path, device)?.into(); + + let rmsnorm = RMSNorm { + weight: weight.into(), + eps: eps, + }; + + Ok(rmsnorm) +} + +fn load_attention( + path: &str, + device: &B::Device, +) -> Result, Box> { + let query = load_linear(&format!("{}/{}", path, "wq"), device)?; + let key = load_linear(&format!("{}/{}", path, "wk"), device)?; + let value = load_linear(&format!("{}/{}", path, "wv"), device)?; + let out = load_linear(&format!("{}/{}", path, "wo"), device)?; + + let n_head = load_usize::("n_head", path, device)?; + let n_kv_head = load_usize::("n_kv_head", path, device)?; + + let attention = MultiHeadSelfAttention { + n_head: n_head, + n_kv_head: n_kv_head, + query: query, + key: key, + value: value, + out: out, + }; + + Ok(attention) +} + +fn load_feedforward(path: &str, device: &B::Device) -> Result, Box> { + let w1 = load_linear(&format!("{}/{}", path, "w1"), device)?; + let w2 = load_linear(&format!("{}/{}", path, "w2"), device)?; + let w3 = load_linear(&format!("{}/{}", path, "w3"), device)?; + + let mlp = MLP { + w1: w1, + w2: w2, + w3: w3, + silu: SILU::new(), + }; + + Ok(mlp) +} + +fn load_transformer_block( + path: &str, + device: &B::Device, +) -> Result, Box> { + let attn = load_attention(&format!("{}/{}", path, "attention"), device)?; + let attn_norm = load_rmsnorm(&format!("{}/{}", path, "attention_norm"), device)?; + let mlp = load_feedforward(&format!("{}/{}", path, "feedforward"), device)?; + let mlp_norm = load_rmsnorm(&format!("{}/{}", path, "ffn_norm"), device)?; + + let block = ResidualDecoderAttentionBlock { + attn: attn, + attn_norm: attn_norm, + mlp: mlp, + mlp_norm: mlp_norm, + }; + + Ok(block) +} + +use burn::nn::{EmbeddingConfig, EmbeddingRecord}; + +pub fn load_llama_dump( + path: &str, + device: &B::Device, +) -> Result<(Llama, LlamaConfig), Box> { + let mut blocks: Vec> = vec![]; + let n_layer = load_usize::("n_layer", path, device)?; + 
for i in 0..n_layer { + let block = load_transformer_block(&format!("{}/layer{}", path, i), device)?; + blocks.push(block); + } + + let n_ctx = load_usize::("n_ctx", path, device)?; + let theta = load_f32::("theta", path, device)?; + let multiple_of = load_usize::("multiple_of", path, device)?; + let ffn_dim_multiplier = load_usize::("ffn_dim_multiplier", path, device).ok(); + + let token_embedding = load_tensor("tok_embeddings/weight", path, device)?; + let [n_vocab, n_state] = token_embedding.dims(); + let n_head = blocks[0].attn.n_head; + let n_kv_head = blocks[0].attn.n_kv_head; + let head_dim = n_state / n_head; + + let token_embedding = EmbeddingConfig::new(n_vocab, n_state).init_with(EmbeddingRecord { + weight: token_embedding.into(), + }); + let rotary_encoding = RotaryEncodingConfig::new(n_ctx, head_dim, theta.into()).init(); + + let norm = load_rmsnorm(&format!("{}/{}", path, "norm"), device)?; + let output = load_linear(&format!("{}/{}", path, "output"), device)?; + let mask = attn_decoder_mask(n_ctx).into(); + + let norm_eps = norm.eps; + + let llama = Llama { + token_embedding: token_embedding, + rotary_encoder: rotary_encoding, + blocks: blocks, + norm: norm, + output: output, + mask: mask, + n_vocab: n_vocab, + n_ctx: n_ctx, + }; + + let llama_config = LlamaConfig::new( + n_vocab, + n_ctx, + n_state, + multiple_of, + n_head, + n_kv_head, + n_layer, + ) + .with_norm_eps(norm_eps) + .with_ffn_dim_multiplier(ffn_dim_multiplier); + + Ok((llama, llama_config)) +} + + +#[cfg(test)] +mod tests{ + use super::*; + + #[test] + fn test_feq_cis_reshape(){ + use burn::backend::WgpuBackend; + use burn::backend::wgpu::{AutoGraphicsApi}; + + type Backend = WgpuBackend; + + let config= RotaryEncodingConfig{ + max_sequence_length: 10, + state_size: 4, + theta: 1.0, + }; + + let encoding=config.init::(); + + assert_eq!(encoding.freq_cis.dims(),[10,4,2]); + assert_eq!(encoding.arange_m.dims(),[2,4]); + } + + #[test] + fn test_rotary_encoding_config_init(){ + use burn::backend::WgpuBackend; + use burn::backend::wgpu::{AutoGraphicsApi}; + + type Backend = WgpuBackend; + + let config = RotaryEncodingConfig{ + state_size: 4, + theta:1.0, + max_sequence_length: 10, + }; + + let encoding=config.init::(); + + assert_eq!(encoding.arange_m.dims(),[2,4]); + assert_eq!(encoding.freq_cis.dims(),[10,4,2]); + + } +} \ No newline at end of file diff --git a/backend/rust/models/src/llama/mod.rs b/backend/rust/models/src/llama/mod.rs new file mode 100644 index 000000000000..c000cb3c53ee --- /dev/null +++ b/backend/rust/models/src/llama/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod llama; + +use llama::*; \ No newline at end of file
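The llama module above is wired in as `pub(crate)` only, so the sketch below is a hypothetical in-crate smoke test showing how `load_llama_dump` and `Llama::forward` might be exercised together once a dumped model is available. The dump path, the NdArray backend alias, and the test location (somewhere inside the `models` crate) are assumptions, not part of this patch series.

    #[cfg(test)]
    mod llama_smoke {
        use crate::llama::llama::load_llama_dump;
        use burn::tensor::backend::Backend;
        use burn::tensor::{Int, Tensor};

        // Assumes the same NdArray backend feature used elsewhere in this workspace.
        type B = burn::backend::NdArrayBackend<f32>;

        #[test]
        #[ignore] // requires a model dump exported by the llama2-burn dump scripts
        fn forward_on_dumped_model() {
            let device = <B as Backend>::Device::default();

            // Illustrative path; point this at a real dump directory to run the test.
            let (llama, _config) = load_llama_dump::<B>("models/src/llama/dump", &device)
                .expect("failed to load llama dump");

            // One sequence of three arbitrary token ids, shaped (n_batch, seq_len).
            let tokens = Tensor::<B, 1, Int>::arange(0..3).reshape([1, 3]);
            let logits = llama.forward(tokens);

            // Output logits should be (n_batch, seq_len, n_vocab).
            let [n_batch, seq_len, _n_vocab] = logits.dims();
            assert_eq!(n_batch, 1);
            assert_eq!(seq_len, 3);
        }
    }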