diff --git a/Cargo.lock b/Cargo.lock index 9a976c0d4..2c12fb17c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,12 +27,6 @@ dependencies = [ "gimli", ] -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - [[package]] name = "adler2" version = "2.0.0" @@ -68,9 +62,9 @@ checksum = "1763692fc1416554cf051efc56a3de5595eca47299d731cc5c2b583adf8b4d2f" [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "android-tzdata" @@ -87,11 +81,17 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" -version = "0.6.17" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -104,9 +104,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" @@ -138,18 +138,18 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.91" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" dependencies = [ "backtrace", ] [[package]] name = "arbitrary" -version = "1.3.2" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" dependencies = [ "derive_arbitrary", ] @@ -174,7 +174,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -191,9 +191,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", @@ -253,7 +253,7 @@ dependencies = [ "addr2line", "cfg-if", "libc", - "miniz_oxide 0.8.0", + "miniz_oxide", "object", "rustc-demangle", "windows-targets 0.52.6", @@ -341,8 +341,8 @@ dependencies = [ "proc-macro2", "quote", "str_inflector", - "syn 2.0.85", - "thiserror", + "syn 2.0.87", + "thiserror 1.0.69", "try_match", ] @@ -354,9 +354,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.19.0" +version 
= "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" dependencies = [ "bytemuck_derive", ] @@ -369,7 +369,7 @@ checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -391,89 +391,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] -name = "candle-core" -version = "0.8.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" -dependencies = [ - "accelerate-src", - "byteorder", - "candle-kernels", - "candle-metal-kernels", - "cudarc", - "float8", - "gemm", - "half", - "intel-mkl-src", - "libc", - "memmap2", - "metal 0.27.0", - "num-traits", - "num_cpus", - "rand", - "rand_distr", - "rayon", - "safetensors", - "thiserror", - "ug", - "ug-cuda", - "ug-metal", - "yoke", - "zip", -] - -[[package]] -name = "candle-flash-attn" -version = "0.8.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" -dependencies = [ - "anyhow", - "bindgen_cuda 0.1.5", - "candle-core", - "half", -] - -[[package]] -name = "candle-kernels" -version = "0.8.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" -dependencies = [ - "bindgen_cuda 0.1.5", -] - -[[package]] -name = "candle-metal-kernels" -version = "0.8.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" -dependencies = [ - "metal 0.27.0", - "once_cell", - "thiserror", - "tracing", -] - -[[package]] -name = "candle-nn" -version = "0.8.0" -source = "git+https://github.com/EricLBuehler/candle.git?rev=cb8082b#cb8082bf28eadb4140c1774f002983a6d182bf3e" -dependencies = [ - "accelerate-src", - "candle-core", - "candle-metal-kernels", - "half", - "intel-mkl-src", - "metal 0.27.0", - "num-traits", - "rayon", - "safetensors", - "serde", - "thiserror", -] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.31" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "shlex", ] @@ -492,9 +419,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "cfgrammar" -version = "0.13.7" +version = "0.13.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6026d8cd82ada8bbcfe337805dd1eb6afdc9e80fa4d57e977b3a36315e0c5525" +checksum = "6d621f687a04efa1f269f1cd13d8cfea9660852bdb3d1cd2c3c9fb6fdd34daf2" dependencies = [ "indexmap", "lazy_static", @@ -540,11 +467,38 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -552,9 +506,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -572,14 +526,14 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "cli-table" @@ -590,7 +544,7 @@ dependencies = [ "cli-table-derive", "csv", "termcolor", - "unicode-width", + "unicode-width 0.1.14", ] [[package]] @@ -625,7 +579,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", - "unicode-width", + "unicode-width 0.1.14", "windows-sys 0.52.0", ] @@ -658,9 +612,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" dependencies = [ "libc", ] @@ -674,6 +628,40 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -742,9 +730,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -826,7 +814,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -848,7 +836,7 @@ checksum = 
"d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core 0.20.10", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -871,18 +859,18 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] name = "derive_arbitrary" -version = "1.3.2" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" +checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -903,7 +891,7 @@ dependencies = [ "darling 0.20.10", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -913,7 +901,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -924,7 +912,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -975,7 +963,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1033,7 +1021,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1063,15 +1051,14 @@ dependencies = [ [[package]] name = "exr" -version = "1.72.0" +version = "1.73.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" +checksum = "f83197f59927b46c04a183a619b7c29df34e63e63c7869320862268c0ef687e0" dependencies = [ "bit_field", - "flume", "half", "lebe", - "miniz_oxide 0.7.4", + "miniz_oxide", "rayon-core", "smallvec", "zune-inflate", @@ -1089,15 +1076,15 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "fdeflate" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab" +checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb" dependencies = [ "simd-adler32", ] @@ -1116,12 +1103,12 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", - "miniz_oxide 0.8.0", + "miniz_oxide", ] [[package]] @@ -1137,15 +1124,6 @@ dependencies = [ "rand_distr", ] -[[package]] -name = "flume" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" -dependencies = [ - "spin", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1179,7 +1157,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1259,7 +1237,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1437,8 +1415,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1450,7 +1430,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -1477,9 +1457,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", @@ -1520,9 +1500,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" [[package]] name = "heck" @@ -1536,6 +1516,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + [[package]] name = "hf-hub" version = "0.3.2" @@ -1549,7 +1535,7 @@ dependencies = [ "rand", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "ureq", ] @@ -1601,9 +1587,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", @@ -1656,9 +1642,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" dependencies = [ "bytes", "futures-channel", @@ -1696,6 +1682,124 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -1704,19 +1808,30 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] name = "image" -version = "0.25.4" +version = "0.25.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc144d44a31d753b02ce64093d532f55ff8dc4ebf2ffb8a63c0dda691385acae" +checksum = "cd6f44aed642f18953a158afeb30206f4d50da59fbc66ecb53c66488de73563b" dependencies = [ "bytemuck", "byteorder-lite", @@ -1755,22 +1870,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" 
dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "serde", ] [[package]] name = "indicatif" -version = "0.17.8" +version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" dependencies = [ "console", - "instant", "number_prefix", "portable-atomic", "rayon", - "unicode-width", + "unicode-width 0.2.0", + "web-time", ] [[package]] @@ -1779,15 +1894,6 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" -[[package]] -name = "instant" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" -dependencies = [ - "cfg-if", -] - [[package]] name = "intel-mkl-src" version = "0.8.1" @@ -1816,12 +1922,32 @@ version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.11.0" @@ -1884,9 +2010,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.161" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libloading" @@ -1900,9 +2026,9 @@ dependencies = [ [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libredox" @@ -1921,6 +2047,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "lock_api" version = "0.4.12" @@ -1939,9 +2071,9 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lrtable" -version = "0.13.7" +version = "0.13.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d42d2752cb50a171efadda0cb6fa97432e8bf05accfff3eed320b87e80a2f69e" +checksum = 
"49e35162de3a5d91b380f8ebb31fc6c5e9a4618276465df4725ff1f88613312b" dependencies = [ "cfgrammar", "fnv", @@ -1990,6 +2122,88 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "mcandle-core" +version = "0.3.2" +dependencies = [ + "accelerate-src", + "anyhow", + "byteorder", + "clap", + "criterion", + "cudarc", + "float8", + "gemm", + "half", + "intel-mkl-src", + "libc", + "mcandle-kernels", + "mcandle-metal-kernels", + "memmap2", + "metal 0.27.0", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror 1.0.69", + "ug", + "ug-cuda", + "ug-metal", + "yoke", + "zip", +] + +[[package]] +name = "mcandle-flash-attn" +version = "0.3.2" +dependencies = [ + "anyhow", + "bindgen_cuda 0.1.5", + "half", + "mcandle-core", +] + +[[package]] +name = "mcandle-kernels" +version = "0.3.2" +dependencies = [ + "bindgen_cuda 0.1.5", +] + +[[package]] +name = "mcandle-metal-kernels" +version = "0.3.2" +dependencies = [ + "metal 0.27.0", + "once_cell", + "thiserror 1.0.69", + "tracing", +] + +[[package]] +name = "mcandle-nn" +version = "0.3.2" +dependencies = [ + "accelerate-src", + "anyhow", + "clap", + "criterion", + "half", + "intel-mkl-src", + "mcandle-core", + "mcandle-flash-attn", + "mcandle-metal-kernels", + "metal 0.27.0", + "num-traits", + "rand", + "rayon", + "safetensors", + "serde", + "thiserror 1.0.69", +] + [[package]] name = "memchr" version = "2.7.4" @@ -2063,9 +2277,9 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9ca8daf4b0b4029777f1bc6e1aedd1aec7b74c276a43bc6f620a8e1a1c0a90e" +checksum = "2c37e1b517d1dcd0e51dc36c4567b9d5a29262b3ec8da6cb5d35e27a8fb529b5" dependencies = [ "serde", "serde_json", @@ -2073,9 +2287,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.3.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a" +checksum = "7fe51f1a6a8285f03fcd1544d834234fe8db285f29e1c2253600c93b3ae19242" dependencies = [ "minijinja", "serde", @@ -2087,15 +2301,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "miniz_oxide" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" -dependencies = [ - "adler", -] - [[package]] name = "miniz_oxide" version = "0.8.0" @@ -2124,7 +2329,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -2135,11 +2340,11 @@ name = "mistralrs" version = "0.3.2" dependencies = [ "anyhow", - "candle-core", "either", "futures", "image", "indexmap", + "mcandle-core", "mistralrs-core", "rand", "reqwest", @@ -2153,9 +2358,9 @@ name = "mistralrs-bench" version = "0.3.2" dependencies = [ "anyhow", - "candle-core", "clap", "cli-table", + "mcandle-core", "mistralrs-core", "serde", "serde_json", @@ -2177,9 +2382,6 @@ dependencies = [ "buildstructor", "bytemuck", "bytemuck_derive", - "candle-core", - "candle-flash-attn", - 
"candle-nn", "cfgrammar", "chrono", "clap", @@ -2199,6 +2401,9 @@ dependencies = [ "intel-mkl-src", "itertools 0.13.0", "lrtable", + "mcandle-core", + "mcandle-flash-attn", + "mcandle-nn", "minijinja", "minijinja-contrib", "mistralrs-paged-attn", @@ -2211,7 +2416,7 @@ dependencies = [ "rand_isaac", "rayon", "regex", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "reqwest", "rustc-hash", "safetensors", @@ -2222,7 +2427,7 @@ dependencies = [ "serde_yaml", "strum", "sysinfo", - "thiserror", + "thiserror 1.0.69", "tokenizers", "tokio", "tokio-rayon", @@ -2241,9 +2446,9 @@ version = "0.3.2" dependencies = [ "anyhow", "bindgen_cuda 0.1.6", - "candle-core", "float8", "half", + "mcandle-core", ] [[package]] @@ -2253,13 +2458,13 @@ dependencies = [ "accelerate-src", "anyhow", "base64 0.22.1", - "candle-core", "data-url", "either", "futures", "image", "indexmap", "intel-mkl-src", + "mcandle-core", "mistralrs-core", "pyo3", "pyo3-build-config", @@ -2276,17 +2481,17 @@ version = "0.3.2" dependencies = [ "bindgen_cuda 0.1.5", "byteorder", - "candle-core", - "candle-nn", "float8", "half", "lazy_static", + "mcandle-core", + "mcandle-nn", "metal 0.27.0", "once_cell", "paste", "rayon", "serde", - "thiserror", + "thiserror 1.0.69", "tracing", ] @@ -2297,7 +2502,6 @@ dependencies = [ "accelerate-src", "anyhow", "axum", - "candle-core", "clap", "ctrlc", "data-url", @@ -2306,6 +2510,7 @@ dependencies = [ "image", "indexmap", "intel-mkl-src", + "mcandle-core", "mistralrs-core", "once_cell", "regex", @@ -2324,8 +2529,8 @@ dependencies = [ name = "mistralrs-vision" version = "0.3.2" dependencies = [ - "candle-core", "image", + "mcandle-core", ] [[package]] @@ -2346,7 +2551,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2497,7 +2702,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2519,7 +2724,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2568,7 +2773,7 @@ dependencies = [ "serde_json", "strum", "strum_macros", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -2590,7 +2795,7 @@ dependencies = [ "serde_json", "sha2", "tar", - "thiserror", + "thiserror 1.0.69", "toml", "ureq", "url", @@ -2626,6 +2831,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openssl" version = "0.10.68" @@ -2649,7 +2860,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2802,7 +3013,7 @@ dependencies = [ "crc32fast", "fdeflate", "flate2", - "miniz_oxide 0.8.0", + "miniz_oxide", ] [[package]] @@ -2872,7 +3083,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2898,9 +3109,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.5" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" +checksum = 
"f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" dependencies = [ "anyhow", "cfg-if", @@ -2929,9 +3140,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.5" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" dependencies = [ "once_cell", "target-lexicon", @@ -2939,9 +3150,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.5" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" dependencies = [ "libc", "pyo3-build-config", @@ -2949,27 +3160,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.5" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] name = "pyo3-macros-backend" -version = "0.22.5" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" dependencies = [ "heck", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -2989,9 +3200,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", @@ -3000,34 +3211,38 @@ dependencies = [ "rustc-hash", "rustls", "socket2", - "thiserror", + "thiserror 2.0.3", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash", "rustls", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.3", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" dependencies = [ + "cfg_aliases", "libc", "once_cell", "socket2", @@ -3166,7 +3381,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -3177,7 +3392,7 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ 
-3192,9 +3407,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -3215,9 +3430,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", @@ -3297,7 +3512,7 @@ dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.85", + "syn 2.0.87", "walkdir", ] @@ -3344,9 +3559,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -3357,9 +3572,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.15" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fbb44d7acc4e873d613422379f69f237a1b141928c02f6bc6ccfddddc2d7993" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "log", "once_cell", @@ -3384,6 +3599,9 @@ name = "rustls-pki-types" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -3429,9 +3647,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -3457,7 +3675,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -3481,9 +3699,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -3503,22 +3721,22 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.213" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ea7893ff5e2466df8d720bb615088341b295f849602c6956047f8f80f0e9bc1" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.213" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e85ad2009c50b58e87caa8cd6dac16bdf511bbfb7af6c33df902396aa480fa5" +checksum = 
"ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -3529,14 +3747,14 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", @@ -3707,9 +3925,6 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] [[package]] name = "spm_precompiled" @@ -3770,7 +3985,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -3792,9 +4007,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.85" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -3824,7 +4039,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -3837,7 +4052,7 @@ dependencies = [ "byteorder", "enum-as-inner", "libc", - "thiserror", + "thiserror 1.0.69", "walkdir", ] @@ -3879,9 +4094,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ff6c40d3aedb5e06b57c6f669ad17ab063dd1e63d977c6a88e7f4dfa4f04020" +checksum = "c65998313f8e17d0d553d28f91a0df93e4dbbbf770279c7bc21ca0f09ea1a1f6" dependencies = [ "filetime", "libc", @@ -3896,9 +4111,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", "fastrand", @@ -3928,22 +4143,42 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ - "thiserror-impl", + "thiserror-impl 2.0.3", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", 
+] + +[[package]] +name = "thiserror-impl" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", ] [[package]] @@ -3967,6 +4202,26 @@ dependencies = [ "weezl", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -4008,7 +4263,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror", + "thiserror 1.0.69", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -4016,9 +4271,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -4040,7 +4295,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4196,7 +4451,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4261,7 +4516,7 @@ checksum = "b9c81686f7ab4065ccac3df7a910c4249f8c0f3fb70421d6ddec19b9311f63f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4280,7 +4535,7 @@ dependencies = [ "num", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -4293,7 +4548,7 @@ dependencies = [ "half", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "ug", ] @@ -4308,7 +4563,7 @@ dependencies = [ "objc", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "ug", ] @@ -4324,27 +4579,12 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" -[[package]] -name = "unicode-bidi" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" - [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" -[[package]] -name = "unicode-normalization" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" -dependencies = [ - "tinyvec", -] - [[package]] name = "unicode-normalization-alignments" version = "0.1.12" @@ -4366,6 +4606,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "unicode_categories" version = "0.1.1" @@ -4411,15 +4657,27 @@ dependencies = [ [[package]] name = "url" -version = "2.5.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -4448,7 +4706,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4577,7 +4835,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "wasm-bindgen-shared", ] @@ -4611,7 +4869,7 @@ checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4632,6 +4890,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.6" @@ -4884,6 +5152,18 @@ dependencies = [ "memchr", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "xattr" version = "1.3.1" @@ -4915,7 +5195,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "synstructure", ] @@ -4937,7 +5217,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", ] [[package]] @@ -4957,7 +5237,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.85", + "syn 2.0.87", "synstructure", ] @@ -4967,6 +5247,28 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "zip" version = "1.1.4" @@ -4980,7 +5282,7 @@ dependencies = [ "flate2", "indexmap", "num_enum", - "thiserror", + "thiserror 1.0.69", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a116c0da5..70198542e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,14 @@ members = [ "mistralrs-bench", "mistralrs-vision", "mistralrs-quant", + "mcandle/mcandle-core", + "mcandle/mcandle-nn", ] exclude = [ "mistralrs-paged_attn", + "mcandle/mcandle-flash-attn", + "mcandle/mcandle-kernels", + "mcandle/mcandle-metal-kernels", ] resolver = "2" @@ -25,8 +30,11 @@ license = "MIT" [workspace.dependencies] anyhow = "1.0.80" -candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "cb8082b" } -candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "cb8082b" } +mcandle-core = { path = "./mcandle/mcandle-core", package = "mcandle-core" } +mcandle-flash-attn = { path = "./mcandle/mcandle-flash-attn" } +mcandle-kernels = { path = "./mcandle/mcandle-kernels" } +mcandle-metal-kernels = { path = "./mcandle/mcandle-metal-kernels" } +mcandle-nn = { path = "./mcandle/mcandle-nn" } serde = "1.0.197" serde_json = "1.0.114" indexmap = { version = "2.2.5", features = ["serde"] } @@ -49,5 +57,23 @@ rayon = "1.1.0" url = "2.5.2" data-url = "0.3.1" buildstructor = "0.5.4" -float8 = "0.1.1" +float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] } regex = "1.10.6" +byteorder = "1.4.3" +cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } +gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } +libc = { version = "0.2.147" } +memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } +yoke = { version = "0.7.2", features = ["derive"] } +zip = { version = "1.1.1", default-features = false } +metal = { version = "0.27.0", features = ["mps"]} +num_cpus = "1.15.0" +num-traits = "0.2.15" +rand = "0.8.5" +rand_distr = "0.4.3" +safetensors = "0.4.1" +thiserror = "1" +ug = "0.0.2" +ug-cuda = "0.0.2" +ug-metal = "0.0.2" +criterion = { version = "0.5.1", default-features=false } diff --git a/README.md b/README.md index 8b6a5be3f..af229409a 100644 --- a/README.md +++ b/README.md @@ -597,7 +597,7 @@ If you want to add a new model, please contact us via an issue and we can coordi - For non-quantized models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. ## Credits -This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality. +This project would not be possible without the excellent work at [`mcandle`](https://github.com/huggingface/mcandle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality.

⬆️ Back to Top diff --git a/mcandle/mcandle-core/Cargo.toml b/mcandle/mcandle-core/Cargo.toml new file mode 100644 index 000000000..44c9fd018 --- /dev/null +++ b/mcandle/mcandle-core/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "mcandle-core" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +accelerate-src = { workspace = true, optional = true } +byteorder = { workspace = true } +mcandle-kernels = { workspace = true, optional = true } +mcandle-metal-kernels = { workspace = true, optional = true } +metal = { workspace = true, optional = true} +cudarc = { workspace = true, optional = true } +gemm = { workspace = true } +half = { workspace = true } +float8 = { workspace = true } +intel-mkl-src = { workspace = true, optional = true } +libc = { workspace = true, optional = true } +memmap2 = { workspace = true } +num-traits = { workspace = true } +num_cpus = { workspace = true } +rand = { workspace = true } +rand_distr = { workspace = true } +rayon = { workspace = true } +safetensors = { workspace = true } +thiserror = { workspace = true } +ug = { workspace = true } +ug-cuda = { workspace = true, optional = true } +ug-metal = { workspace = true, optional = true } +yoke = { workspace = true } +zip = { workspace = true } + +[dev-dependencies] +anyhow = { workspace = true } +clap = { workspace = true } +criterion = { workspace = true } + + +[features] +default = [] +cuda = ["cudarc", "dep:mcandle-kernels", "float8/cuda", "dep:ug-cuda"] +cudnn = ["cuda", "cudarc/cudnn"] +mkl = ["dep:libc", "dep:intel-mkl-src"] +accelerate = ["dep:libc", "dep:accelerate-src"] +metal = ["dep:metal", "dep:mcandle-metal-kernels", "dep:ug-metal"] + +[[bench]] +name = "bench_main" +harness = false + +[[example]] +name = "metal_basics" +required-features = ["metal"] diff --git a/mcandle/mcandle-core/LICENSE b/mcandle/mcandle-core/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/mcandle/mcandle-core/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/mcandle/mcandle-core/README.md b/mcandle/mcandle-core/README.md new file mode 100644 index 000000000..dd6c2572f --- /dev/null +++ b/mcandle/mcandle-core/README.md @@ -0,0 +1,2 @@ +# mcandle +Minimalist ML framework for Rust diff --git a/mcandle/mcandle-core/benches/bench_main.rs b/mcandle/mcandle-core/benches/bench_main.rs new file mode 100644 index 000000000..2e1816fd7 --- /dev/null +++ b/mcandle/mcandle-core/benches/bench_main.rs @@ -0,0 +1,12 @@ +mod benchmarks; + +use criterion::criterion_main; +criterion_main!( + benchmarks::affine::benches, + benchmarks::matmul::benches, + benchmarks::random::benches, + benchmarks::where_cond::benches, + benchmarks::conv_transpose2d::benches, + benchmarks::qmatmul::benches, + benchmarks::unary::benches +); diff --git a/mcandle/mcandle-core/benches/benchmarks/affine.rs b/mcandle/mcandle-core/benches/benchmarks/affine.rs new file mode 100644 index 000000000..34b50f528 --- /dev/null +++ b/mcandle/mcandle-core/benches/benchmarks/affine.rs @@ -0,0 +1,43 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{DType, Device, Tensor}; +use std::time::Instant; + +fn run(a: &Tensor) { + a.affine(12.34, 56.78).unwrap(); +} + +fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap(); + + let flops = b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_affine_benchmark(c, &device, DType::F32, "affine_f32"); + run_affine_benchmark(c, &device, DType::F16, "affine_f16"); + run_affine_benchmark(c, &device, DType::BF16, "affine_bf16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-core/benches/benchmarks/conv_transpose2d.rs b/mcandle/mcandle-core/benches/benchmarks/conv_transpose2d.rs new file mode 100644 index 000000000..b38f235a5 --- /dev/null +++ b/mcandle/mcandle-core/benches/benchmarks/conv_transpose2d.rs @@ -0,0 +1,59 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{DType, Device, Tensor}; +use std::time::Instant; + +fn run( + x: &Tensor, + k: &Tensor, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, +) { + x.conv_transpose2d(k, padding, output_padding, stride, dilation) + .unwrap(); +} + +fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let t = Tensor::arange(0.0f32, 10000.0, device) + .unwrap() + .reshape((1, 4, 50, 50)) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let kernel = Tensor::arange(0.0f32, 100.0, device) + .unwrap() + .reshape((4, 1, 5, 5)) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let flops = t.dims().iter().product::() * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let 
start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
+        run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
+        run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/matmul.rs b/mcandle/mcandle-core/benches/benchmarks/matmul.rs
new file mode 100644
index 000000000..498b82df2
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/matmul.rs
@@ -0,0 +1,44 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use mcandle_core::{DType, Device, Tensor};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) {
+    a.matmul(&b.t().unwrap()).unwrap();
+}
+
+fn run_bench(c: &mut Criterion, device: &Device) {
+    let b = 1;
+    let m = 1;
+    let n = 2048;
+    let k = 2048;
+
+    let dtype = DType::F32;
+    let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
+    let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
+
+    let flops = b * m * n * k;
+
+    let mut group = c.benchmark_group(device.bench_name("matmul"));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&lhs), black_box(&rhs));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        run_bench(c, &device);
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/mcandle/mcandle-core/benches/benchmarks/mod.rs b/mcandle/mcandle-core/benches/benchmarks/mod.rs
new file mode 100644
index 000000000..986488e5f
--- /dev/null
+++ b/mcandle/mcandle-core/benches/benchmarks/mod.rs
@@ -0,0 +1,72 @@
+pub(crate) mod affine;
+pub(crate) mod conv_transpose2d;
+pub(crate) mod matmul;
+pub(crate) mod qmatmul;
+pub(crate) mod random;
+pub(crate) mod unary;
+pub(crate) mod where_cond;
+
+use mcandle_core::{Device, Result};
+
+pub(crate) trait BenchDevice {
+    fn sync(&self) -> Result<()>;
+
+    fn bench_name<S: Into<String>>(&self, name: S) -> String;
+}
+
+impl BenchDevice for Device {
+    fn sync(&self) -> Result<()> {
+        match self {
+            Device::Cpu => Ok(()),
+            Device::Cuda(device) => {
+                #[cfg(feature = "cuda")]
+                {
+                    use cuda::WrapErr;
+                    return Ok(device.synchronize().w()?);
+                }
+                #[cfg(not(feature = "cuda"))]
+                panic!("Cuda device without cuda feature enabled: {:?}", device)
+            }
+            Device::Metal(device) => {
+                #[cfg(feature = "metal")]
+                return Ok(device.wait_until_completed()?);
+                #[cfg(not(feature = "metal"))]
+                panic!("Metal device without metal feature enabled: {:?}", device)
+            }
+        }
+    }
+
+    fn bench_name<S: Into<String>>(&self, name: S) -> String {
+        match self {
+            Device::Cpu => {
+                let cpu_type = if cfg!(feature = "accelerate") {
+                    "accelerate"
+                } else if cfg!(feature = "mkl") {
+                    "mkl"
+                } else {
+                    "cpu"
+                };
+                format!("{}_{}", cpu_type, name.into())
+            }
+            Device::Cuda(_) => format!("cuda_{}", name.into()),
+            Device::Metal(_) => format!("metal_{}", name.into()),
+        }
+    }
+}
+
+struct BenchDeviceHandler {
+    devices: Vec<Device>,
+}
+
+impl BenchDeviceHandler {
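+    /// Builds the list of devices to benchmark: Metal if the `metal` feature is
+    /// enabled, otherwise CUDA if the `cuda` feature is enabled, always followed by
+    /// the CPU backend.
+    ///
+    /// Illustrative usage (this sketch just mirrors the `criterion_benchmark`
+    /// functions in this directory; it is not part of the original patch):
+    ///
+    /// ```ignore
+    /// let handler = BenchDeviceHandler::new().unwrap();
+    /// for device in handler.devices {
+    ///     // register criterion benchmarks against `device`
+    /// }
+    /// ```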
+ pub fn new() -> Result { + let mut devices = Vec::new(); + if cfg!(feature = "metal") { + devices.push(Device::new_metal(0)?); + } else if cfg!(feature = "cuda") { + devices.push(Device::new_cuda(0)?); + } + devices.push(Device::Cpu); + Ok(Self { devices }) + } +} diff --git a/mcandle/mcandle-core/benches/benchmarks/qmatmul.rs b/mcandle/mcandle-core/benches/benchmarks/qmatmul.rs new file mode 100644 index 000000000..965266f4c --- /dev/null +++ b/mcandle/mcandle-core/benches/benchmarks/qmatmul.rs @@ -0,0 +1,72 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{ + quantized::{self, GgmlDType, QMatMul}, + Device, Module, Tensor, +}; +use std::time::Instant; + +fn run(matmul: &QMatMul, x: &Tensor) { + matmul.forward(x).unwrap(); +} + +fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) { + let b = 1; + let m = 1; + let n = 1024; + let k = 1024; + + let lhs = (0..(m * k)) + .map(|v| v as f32 / (m * k) as f32) + .collect::>(); + let rhs = (0..(k * n)) + .map(|v| v as f32 / (n * k) as f32) + .collect::>(); + + let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap(); + let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap(); + + let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap(); + let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap(); + + let flops = b * m * n * k; + + let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype))); + group.sample_size(200); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&matmul), black_box(&lhs)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + for dtype in [ + GgmlDType::F32, + GgmlDType::F16, + GgmlDType::Q4_0, + GgmlDType::Q4_1, + GgmlDType::Q5_0, + GgmlDType::Q5_1, + GgmlDType::Q8_0, + GgmlDType::Q2K, + GgmlDType::Q3K, + GgmlDType::Q4K, + GgmlDType::Q5K, + GgmlDType::Q6K, + ] { + run_bench(c, &device, dtype); + } + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-core/benches/benchmarks/random.rs b/mcandle/mcandle-core/benches/benchmarks/random.rs new file mode 100644 index 000000000..e0c74b5e0 --- /dev/null +++ b/mcandle/mcandle-core/benches/benchmarks/random.rs @@ -0,0 +1,63 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{DType, Device, Tensor}; +use std::time::Instant; + +fn rand_uniform(a: &Tensor) { + a.rand_like(-1.0, 123.0).unwrap(); +} + +fn rand_normal(a: &Tensor) { + a.randn_like(100.0, 15.0).unwrap(); +} + +fn run_random_bench(c: &mut Criterion, device: &Device) { + let b = 1; + + let rows = 2048; + let cols = 2048; + + let dtype = DType::F32; + let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap(); + + let flops = b * rows * cols * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name("random_uniform")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |benches| { + benches.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + rand_uniform(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + 
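+    // Note: the `Throughput::Bytes` value above is the number of bytes generated per
+    // iteration (b * rows * cols * dtype.size_in_bytes()), so criterion reports a
+    // byte rate rather than element or FLOP throughput, despite the `flops` name.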
group.finish(); + + let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap(); + + let mut group = c.benchmark_group(device.bench_name("random_normal")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |benches| { + benches.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + rand_normal(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_random_bench(c, &device); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-core/benches/benchmarks/unary.rs b/mcandle/mcandle-core/benches/benchmarks/unary.rs new file mode 100644 index 000000000..f13b17948 --- /dev/null +++ b/mcandle/mcandle-core/benches/benchmarks/unary.rs @@ -0,0 +1,49 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{DType, Device, Tensor}; +use std::time::Instant; + +fn run(a: &Tensor) { + a.sqrt().unwrap(); +} + +fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device) + .unwrap() + .to_dtype(dtype) + .unwrap() + .reshape((b, m, k)) + .unwrap(); + + let flops = b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + for dtype in [DType::F32, DType::BF16, DType::F16] { + let name = format!("sqrt_{:?}", dtype); + run_unary_benchmark(c, &device, dtype, &name); + } + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-core/benches/benchmarks/where_cond.rs b/mcandle/mcandle-core/benches/benchmarks/where_cond.rs new file mode 100644 index 000000000..2ed1f85c9 --- /dev/null +++ b/mcandle/mcandle-core/benches/benchmarks/where_cond.rs @@ -0,0 +1,64 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{DType, Device, Tensor}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor, c: &Tensor) { + a.where_cond(b, c).unwrap(); +} + +const fn create_cond_arr() -> [u8; N] { + let mut arr = [0u8; N]; + let mut i = 0; + while i < N { + arr[i] = (i % 2) as u8; + i += 1; + } + arr +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; +const SIZE: usize = B * M * K; + +const DATA: [u8; SIZE] = create_cond_arr::(); + +fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap(); + let on_true = Tensor::ones((B, M, K), dtype, device).unwrap(); + let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap(); + + let elements = B * M * K; + // E.g. 
2 f32 tensors + 1 u8 tensor + let flops = (2 * elements * dtype.size_in_bytes()) + elements; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(&tensor), + black_box(&on_true), + black_box(&on_false), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32"); + run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16"); + run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-core/examples/basics.rs b/mcandle/mcandle-core/examples/basics.rs new file mode 100644 index 000000000..b88c36fa7 --- /dev/null +++ b/mcandle/mcandle-core/examples/basics.rs @@ -0,0 +1,17 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use mcandle_core::{Device, Tensor}; + +fn main() -> Result<()> { + let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?; + let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?; + let new_a = a.slice_scatter(&b, 1, 2)?; + assert_eq!(a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(new_a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + Ok(()) +} diff --git a/mcandle/mcandle-core/examples/cuda_basics.rs b/mcandle/mcandle-core/examples/cuda_basics.rs new file mode 100644 index 000000000..2457714b1 --- /dev/null +++ b/mcandle/mcandle-core/examples/cuda_basics.rs @@ -0,0 +1,33 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use mcandle_core::{Device, Tensor}; + +fn main() -> Result<()> { + let device = Device::new_cuda(0)?; + let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)? 
+ .to_dtype(mcandle_core::DType::BF16)?; + mcandle_core::cuda::set_gemm_reduced_precision_f32(false); + mcandle_core::cuda::set_gemm_reduced_precision_bf16(false); + let _x1 = x.matmul(&x)?; + drop(_x1); + let start_time = std::time::Instant::now(); + let _x1 = x.matmul(&x)?; + device.synchronize()?; + println!("fp32: {:?}", start_time.elapsed()); + drop(_x1); + mcandle_core::cuda::set_gemm_reduced_precision_f32(true); + mcandle_core::cuda::set_gemm_reduced_precision_bf16(true); + let _x1 = x.matmul(&x)?; + drop(_x1); + let start_time = std::time::Instant::now(); + let _x1 = x.matmul(&x)?; + device.synchronize()?; + println!("tf32: {:?}", start_time.elapsed()); + drop(_x1); + Ok(()) +} diff --git a/mcandle/mcandle-core/examples/cuda_sum_benchmark.rs b/mcandle/mcandle-core/examples/cuda_sum_benchmark.rs new file mode 100644 index 000000000..5104694d9 --- /dev/null +++ b/mcandle/mcandle-core/examples/cuda_sum_benchmark.rs @@ -0,0 +1,54 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use std::str::FromStr; + +use anyhow::Result; +use mcandle_core::{Device, Tensor}; + +fn cos_sin(n: usize, device: &Device) -> Result { + let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect(); + let xs: Vec<_> = thetas.iter().map(|t| t.cos().abs()).collect(); + let ys: Vec<_> = thetas.iter().map(|t| t.sin().abs()).collect(); + let xs = Tensor::from_vec(xs, (n, 1), device)?; + let ys = Tensor::from_vec(ys, (1, n), device)?; + let ys = Tensor::cat(&[&ys, &ys, &ys, &ys, &ys, &ys], 1)?; + Ok(xs.matmul(&ys)?) +} + +fn main() -> Result<()> { + let device = Device::new_cuda(0)?; + let args = std::env::args().collect::>(); + let n = if args.len() < 2 { + 2000usize + } else { + usize::from_str(&args[1])? + }; + let xys_cpu = cos_sin(n, &Device::Cpu)?; + let xys = cos_sin(n, &device)?; + println!("{xys_cpu:?} {xys:?}"); + let sum_keepdim_cpu = xys_cpu.sum_keepdim(1)?; + println!("{sum_keepdim_cpu}"); + let sum_keepdim = xys.sum_keepdim(1)?; + println!("{sum_keepdim}"); + let start = std::time::Instant::now(); + let n_iters = 100; + let mut v = 0f32; + for _i in 0..n_iters { + let sum_keepdim = xys.sum_keepdim(1)?; + let sum_keepdim = sum_keepdim.sum_keepdim(0)?; + let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?; + v += sum_keepdim; + } + let elapsed = start.elapsed(); + if v > 0. { + println!( + "ran {n_iters} iterations, time per iter: {:?} ({v})", + elapsed.div_f64(n_iters as f64) + ); + } + Ok(()) +} diff --git a/mcandle/mcandle-core/examples/metal_basics.rs b/mcandle/mcandle-core/examples/metal_basics.rs new file mode 100644 index 000000000..830b12b8d --- /dev/null +++ b/mcandle/mcandle-core/examples/metal_basics.rs @@ -0,0 +1,28 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use mcandle_core::{Device, Tensor}; + +fn main() -> Result<()> { + // This requires the code to be run with MTL_CAPTURE_ENABLED=1 + let device = Device::new_metal(0)?; + let metal_device = match &device { + Device::Metal(m) => m, + _ => anyhow::bail!("unexpected device"), + }; + metal_device.capture("/tmp/mcandle.gputrace")?; + // This first synchronize ensures that a new command buffer gets created after setting up the + // capture scope. 
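+    // (Usage sketch, assumed rather than taken from the original patch:
+    //  MTL_CAPTURE_ENABLED=1 cargo run --example metal_basics --features metal
+    //  then open the captured /tmp/mcandle.gputrace, typically with Xcode's Metal
+    //  debugger.)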
+ device.synchronize()?; + let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?; + let x1 = x.add(&x)?; + println!("{x1:?}"); + // This second synchronize ensures that the command buffer gets commited before the end of the + // capture scope. + device.synchronize()?; + Ok(()) +} diff --git a/mcandle/mcandle-core/src/accelerate.rs b/mcandle/mcandle-core/src/accelerate.rs new file mode 100644 index 000000000..d371d3b33 --- /dev/null +++ b/mcandle/mcandle-core/src/accelerate.rs @@ -0,0 +1,476 @@ +#![allow(dead_code)] +use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong}; + +mod ffi { + use super::*; + extern "C" { + // It would be nice to be able to switch to the NEWLAPACK version of the function but this + // seems to trigger some link error. Available function names can be seen here: + // /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd + #[link_name = "sgemm_"] + pub fn sgemm_ffi( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_float, + a: *const c_float, + lda: *const c_int, + b: *const c_float, + ldb: *const c_int, + beta: *const c_float, + c: *mut c_float, + ldc: *const c_int, + ); + #[link_name = "dgemm_"] + pub fn dgemm_ffi( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_double, + a: *const c_double, + lda: *const c_int, + b: *const c_double, + ldb: *const c_int, + beta: *const c_double, + c: *mut c_double, + ldc: *const c_int, + ); + + pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int); + pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int); + pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int); + pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int); + pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int); + pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int); + pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int); + pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int); + pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int); + pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int); + pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int); + pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int); + + pub fn vDSP_vaddD( + _: *const c_double, + _: c_long, + _: *const c_double, + _: c_long, + _: *mut c_double, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vadd( + _: *const c_float, + _: c_long, + _: *const c_float, + _: c_long, + _: *mut c_float, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vsubD( + _: *const c_double, + _: c_long, + _: *const c_double, + _: c_long, + _: *mut c_double, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vsub( + _: *const c_float, + _: c_long, + _: *const c_float, + _: c_long, + _: *mut c_float, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vmulD( + _: *const c_double, + _: c_long, + _: *const c_double, + _: c_long, + _: *mut c_double, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vmul( + _: *const c_float, + _: c_long, + _: *const c_float, + _: c_long, + _: *mut c_float, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vdivD( + _: *const c_double, + _: c_long, + _: *const c_double, + _: c_long, + _: *mut c_double, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vdiv( + _: *const 
c_float, + _: c_long, + _: *const c_float, + _: c_long, + _: *mut c_float, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vminD( + _: *const c_double, + _: c_long, + _: *const c_double, + _: c_long, + _: *mut c_double, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vmin( + _: *const c_float, + _: c_long, + _: *const c_float, + _: c_long, + _: *mut c_float, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vmaxD( + _: *const c_double, + _: c_long, + _: *const c_double, + _: c_long, + _: *mut c_double, + _: c_long, + _: c_ulong, + ); + pub fn vDSP_vmax( + _: *const c_float, + _: c_long, + _: *const c_float, + _: c_long, + _: *mut c_float, + _: c_long, + _: c_ulong, + ); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn sgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f32, + a: &[f32], + lda: i32, + b: &[f32], + ldb: i32, + beta: f32, + c: &mut [f32], + ldc: i32, +) { + ffi::sgemm_ffi( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn dgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f64, + a: &[f64], + lda: i32, + b: &[f64], + ldb: i32, + beta: f64, + c: &mut [f64], + ldc: i32, +) { + ffi::dgemm_ffi( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[inline] +pub fn vs_exp(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_exp(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_sqrt(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_sqrt(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_sin(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_sin(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} +#[inline] +pub fn vs_cos(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_cos(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + 
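+    // Safety: the length check above guarantees both slices hold `a_len` elements,
+    // matching the element count passed to vvcos.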
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} +#[inline] +pub fn vs_tanh(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_tanh(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_ln(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vd_ln(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) } +} + +#[inline] +pub fn vs_sqr(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) +} + +#[inline] +pub fn vd_sqr(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a) +} + +#[inline] +pub fn vs_tanh_inplace(y: &mut [f32]) { + unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vd_tanh_inplace(y: &mut [f64]) { + unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vs_exp_inplace(y: &mut [f32]) { + unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vd_exp_inplace(y: &mut [f64]) { + unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) } +} + +#[inline] +pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vs_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +#[inline] +pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vd_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +#[inline] +pub fn vs_silu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = -v + } + vs_exp_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = v / (1.0 + *y) + } +} + +#[inline] +pub fn vd_silu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = -v + } + vd_exp_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = v / (1.0 + *y) + } +} + +macro_rules! 
binary_op { + ($fn_name:ident, $ty:ty, $accelerate_name:ident) => { + #[inline] + pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) { + let a_len = a.len(); + let b_len = b.len(); + let y_len = y.len(); + if a_len != y_len || b_len != y_len { + panic!( + "{} a,b,y len mismatch {a_len} {b_len} {y_len}", + stringify!($fn_name) + ); + } + unsafe { + // Weird quirk of accelerate, the rhs comes before the lhs. + ffi::$accelerate_name( + b.as_ptr(), + 1, + a.as_ptr(), + 1, + y.as_mut_ptr(), + 1, + a_len as u64, + ) + } + } + }; +} +binary_op!(vs_add, f32, vDSP_vadd); +binary_op!(vd_add, f64, vDSP_vaddD); +binary_op!(vs_sub, f32, vDSP_vsub); +binary_op!(vd_sub, f64, vDSP_vsubD); +binary_op!(vs_mul, f32, vDSP_vmul); +binary_op!(vd_mul, f64, vDSP_vmulD); +binary_op!(vs_div, f32, vDSP_vdiv); +binary_op!(vd_div, f64, vDSP_vdivD); +binary_op!(vs_max, f32, vDSP_vmax); +binary_op!(vd_max, f64, vDSP_vmaxD); +binary_op!(vs_min, f32, vDSP_vmin); +binary_op!(vd_min, f64, vDSP_vminD); diff --git a/mcandle/mcandle-core/src/backend.rs b/mcandle/mcandle-core/src/backend.rs new file mode 100644 index 000000000..655c7894d --- /dev/null +++ b/mcandle/mcandle-core/src/backend.rs @@ -0,0 +1,165 @@ +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Layout, Result, Shape}; + +pub trait BackendStorage: Sized { + type Device: BackendDevice; + + fn try_clone(&self, _: &Layout) -> Result; + + fn dtype(&self) -> DType; + + fn device(&self) -> &Self::Device; + + // Maybe this should return a Cow instead so that no copy is done on the cpu case. + fn to_cpu_storage(&self) -> Result; + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result; + + fn powf(&self, _: &Layout, _: f64) -> Result; + + fn elu(&self, _: &Layout, _: f64) -> Result; + + fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result; + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result; + + fn to_dtype(&self, _: &Layout, _: DType) -> Result; + + fn unary_impl(&self, _: &Layout) -> Result; + + fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result; + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result; + + fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv1D, + ) -> Result; + + fn conv_transpose1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose1D, + ) -> Result; + + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv2D, + ) -> Result; + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose2D, + ) -> Result; + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result; + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; + + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result; + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result; + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result; + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result; + + #[allow(clippy::too_many_arguments)] + fn matmul_with_alpha_beta( + &self, + _: &Self, + _: 
&mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result; + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + // Similar to cudaMemcpy2D, though values are in elements and not in bytes. + fn copy2d( + &self, + _: &mut Self, + _d1: usize, + _d2: usize, + _src_stride1: usize, + _dst_stride1: usize, + _src_offset: usize, + _dst_offset: usize, + ) -> Result<()>; +} + +pub trait BackendDevice: Sized + std::fmt::Debug + Clone { + type Storage: BackendStorage; + + // TODO: Make the usize generic and part of a generic DeviceLocation. + fn new(_: usize) -> Result; + + fn location(&self) -> crate::DeviceLocation; + + fn same_device(&self, _: &Self) -> bool; + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result; + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result; + + /// # Safety + /// This function is unsafe as it doesn't initialize the underlying data store. + /// The caller should ensure that the data is properly initialized as early as possible + /// after this call. + unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result; + + fn storage_from_slice(&self, _: &[T]) -> Result; + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result; + + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result; + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result; + + fn set_seed(&self, _: u64) -> Result<()>; + fn get_current_seed(&self) -> Result; + + /// Synchronize should block until all the operations on the device are completed. + fn synchronize(&self) -> Result<()>; +} diff --git a/mcandle/mcandle-core/src/backprop.rs b/mcandle/mcandle-core/src/backprop.rs new file mode 100644 index 000000000..d7169d631 --- /dev/null +++ b/mcandle/mcandle-core/src/backprop.rs @@ -0,0 +1,764 @@ +/// Methods for backpropagation of gradients. +use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp}; +use crate::{Error, Result, Tensor, TensorId}; +use std::collections::HashMap; + +// arg has been reduced to node via reduce_dims, expand it back to arg. +// This has to handle keepdims. +fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result { + if arg.rank() == node.rank() { + // keepdim = true + node.broadcast_as(arg.shape()) + } else { + // keepdim = false + // first expand the reduced dims. + node.reshape(reduced_dims)?.broadcast_as(arg.shape()) + } +} + +thread_local! { + static CANDLE_GRAD_DO_NOT_DETACH: bool = { + match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") { + Ok(s) => { + !s.is_empty() && s != "0" + }, + Err(_) => false, + } + } +} + +impl Tensor { + /// Return all the nodes that lead to this value in a topologically sorted vec, the first + /// elements having dependencies on the latter ones, e.g. the first element if any is the + /// argument. + /// This assumes that the op graph is a DAG. + fn sorted_nodes(&self) -> Vec<&Tensor> { + // The vec of sorted nodes is passed as an owned value rather than a mutable reference + // to get around some lifetime limitations. 
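+        // `walk` is a depth-first traversal of the op graph: it returns whether the
+        // subtree rooted at `node` contributes a gradient (`track_grad`) and, when it
+        // does, appends `node` to `nodes` after its dependencies. `already_seen`
+        // memoizes the per-tensor result so shared subgraphs are visited only once.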
+ fn walk<'a>( + node: &'a Tensor, + nodes: Vec<&'a Tensor>, + already_seen: &mut HashMap, + ) -> (bool, Vec<&'a Tensor>) { + if let Some(&tg) = already_seen.get(&node.id()) { + return (tg, nodes); + } + let mut track_grad = false; + let mut nodes = if node.is_variable() { + // Do not call recursively on the "leaf" nodes. + track_grad = true; + nodes + } else if node.dtype().is_int() { + nodes + } else if let Some(op) = node.op() { + match op { + Op::IndexAdd(t1, t2, t3, _) + | Op::ScatterAdd(t1, t2, t3, _) + | Op::CustomOp3(t1, t2, t3, _) + | Op::WhereCond(t1, t2, t3) => { + let (tg, nodes) = walk(t1, nodes, already_seen); + track_grad |= tg; + let (tg, nodes) = walk(t2, nodes, already_seen); + track_grad |= tg; + let (tg, nodes) = walk(t3, nodes, already_seen); + track_grad |= tg; + nodes + } + Op::Conv1D { + arg: lhs, + kernel: rhs, + .. + } + | Op::ConvTranspose1D { + arg: lhs, + kernel: rhs, + .. + } + | Op::Conv2D { + arg: lhs, + kernel: rhs, + .. + } + | Op::ConvTranspose2D { + arg: lhs, + kernel: rhs, + .. + } + | Op::CustomOp2(lhs, rhs, _) + | Op::Binary(lhs, rhs, _) + | Op::Gather(lhs, rhs, _) + | Op::IndexSelect(lhs, rhs, _) + | Op::Matmul(lhs, rhs) + | Op::SliceScatter0(lhs, rhs, _) => { + let (tg, nodes) = walk(lhs, nodes, already_seen); + track_grad |= tg; + let (tg, nodes) = walk(rhs, nodes, already_seen); + track_grad |= tg; + nodes + } + Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| { + let (tg, nodes) = walk(arg, nodes, already_seen); + track_grad |= tg; + nodes + }), + Op::Affine { arg, mul, .. } => { + if *mul == 0. { + nodes + } else { + let (tg, nodes) = walk(arg, nodes, already_seen); + track_grad |= tg; + nodes + } + } + Op::Unary(_node, UnaryOp::Ceil) + | Op::Unary(_node, UnaryOp::Floor) + | Op::Unary(_node, UnaryOp::Round) + | Op::Unary(_node, UnaryOp::Sign) => nodes, + Op::Reshape(node) + | Op::UpsampleNearest1D { arg: node, .. } + | Op::UpsampleNearest2D { arg: node, .. } + | Op::AvgPool2D { arg: node, .. } + | Op::MaxPool2D { arg: node, .. } + | Op::Copy(node) + | Op::Broadcast(node) + | Op::Cmp(node, _) + | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _) + | Op::ToDevice(node) + | Op::Transpose(node, _, _) + | Op::Permute(node, _) + | Op::Narrow(node, _, _, _) + | Op::Unary(node, _) + | Op::Elu(node, _) + | Op::Powf(node, _) + | Op::CustomOp1(node, _) => { + let (tg, nodes) = walk(node, nodes, already_seen); + track_grad |= tg; + nodes + } + Op::ToDType(node) => { + if node.dtype().is_float() { + let (tg, nodes) = walk(node, nodes, already_seen); + track_grad |= tg; + nodes + } else { + nodes + } + } + Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes, + } + } else { + nodes + }; + already_seen.insert(node.id(), track_grad); + if track_grad { + nodes.push(node); + } + (track_grad, nodes) + } + let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new()); + nodes.reverse(); + nodes + } + + pub fn backward(&self) -> Result { + let sorted_nodes = self.sorted_nodes(); + let mut grads = GradStore::new(); + grads.insert(self, self.ones_like()?.contiguous()?); + for node in sorted_nodes.iter() { + if node.is_variable() { + continue; + } + let grad = grads + .remove(node) + .expect("mcandle internal error - grad not populated"); + // https://github.com/huggingface/mcandle/issues/1241 + // Ideally, we would make these operations in place where possible to ensure that we + // do not have to allocate too often. Here we just call `.detach` to avoid computing + // the backprop graph of the backprop itself. 
This would be an issue for second order + // derivatives but these are out of scope at the moment. + let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b); + let grad = if do_not_detach { grad } else { grad.detach() }; + if let Some(op) = node.op() { + match op { + Op::Binary(lhs, rhs, BinaryOp::Add) => { + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&grad)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.add(&grad)?; + } + Op::Binary(lhs, rhs, BinaryOp::Sub) => { + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&grad)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.sub(&grad)?; + } + Op::Binary(lhs, rhs, BinaryOp::Mul) => { + let lhs_grad = grad.mul(rhs)?; + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + let rhs_grad = grad.mul(lhs)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + } + Op::Binary(lhs, rhs, BinaryOp::Div) => { + let lhs_grad = grad.div(rhs)?; + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?; + } + Op::Binary(lhs, rhs, BinaryOp::Minimum) + | Op::Binary(lhs, rhs, BinaryOp::Maximum) => { + let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?; + let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?; + + // If both masks are 1 one the same point, we want to scale the + // gradient by 0.5 rather than 1. + let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?; + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + + let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + } + Op::WhereCond(pred, t, f) => { + let zeros = grad.zeros_like()?; + let t_sum_grad = grads.or_insert(t)?; + let t_grad = pred.where_cond(&grad, &zeros)?; + *t_sum_grad = t_sum_grad.add(&t_grad)?; + let f_sum_grad = grads.or_insert(f)?; + let f_grad = pred.where_cond(&zeros, &grad)?; + *f_sum_grad = f_sum_grad.add(&f_grad)?; + } + Op::Conv1D { + arg, + kernel, + padding, + stride, + dilation, + } => { + // The output height for conv_transpose1d is: + // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1 + let grad_l_in = grad.dim(2)?; + let k_size = kernel.dim(2)?; + let out_size = + (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding; + let out_padding = arg.dim(2)? - out_size; + let grad_arg = grad.conv_transpose1d( + kernel, + *padding, + out_padding, + *stride, + *dilation, + /* groups */ 1, + )?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad_arg)?; + + let grad_kernel = arg + .transpose(0, 1)? + .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? + .transpose(0, 1)?; + let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0) = kernel.dims3()?; + let (_, _, g_k0) = grad_kernel.dims3()?; + let grad_kernel = if g_k0 != k0 { + grad_kernel.narrow(2, 0, k0)? 
+ } else { + grad_kernel + }; + *sum_grad = sum_grad.add(&grad_kernel)?; + } + Op::Conv2D { + arg, + kernel, + padding, + stride, + dilation, + } => { + // The output height for conv_transpose2d is: + // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1 + let grad_h = grad.dim(2)?; + let k_h = kernel.dim(2)?; + let out_size = + (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding; + let out_padding = arg.dim(2)? - out_size; + let grad_arg = grad.conv_transpose2d( + kernel, + *padding, + out_padding, + *stride, + *dilation, + )?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad_arg)?; + + let grad_kernel = arg + .transpose(0, 1)? + .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? + .transpose(0, 1)?; + let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; + *sum_grad = sum_grad.add(&grad_kernel)?; + } + Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported { + op: "conv-transpose1d", + })?, + Op::ConvTranspose2D { + arg, + kernel, + padding, + stride, + dilation, + output_padding: _output_padding, + } => { + let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad_arg)?; + + let grad_kernel = grad + .transpose(0, 1)? + .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)? + .transpose(0, 1)?; + let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; + *sum_grad = sum_grad.add(&grad_kernel)?; + } + Op::AvgPool2D { + arg, + kernel_size, + stride, + } => { + if kernel_size != stride { + crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}") + } + let (_n, _c, h, w) = arg.dims4()?; + let grad_arg = grad.upsample_nearest2d(h, w)?; + let grad_arg = + (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad_arg)?; + } + Op::MaxPool2D { + arg, + kernel_size, + stride, + } => { + if kernel_size != stride { + crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}") + } + let (_n, _c, h, w) = arg.dims4()?; + // For computing the max-pool gradient, we compute a mask where a 1 means + // that the element is the maximum, then we apply this mask to the + // upsampled gradient (taking into account that multiple max may exist so + // we scale the gradient for this case). + let node_upsampled = node.upsample_nearest2d(h, w)?; + let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?; + let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?; + let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? 
* mask)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad_arg)?; + } + Op::UpsampleNearest1D { arg, target_size } => { + let (_n, c, size) = arg.dims3()?; + if target_size % size != 0 { + crate::bail!("backward not supported for non integer upscaling factors") + } + let scale = target_size / size; + + let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?; + let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = conv_sum; + } + Op::UpsampleNearest2D { + arg, + target_h, + target_w, + } => { + let (_n, c, h, w) = arg.dims4()?; + if target_h % h != 0 || target_w % w != 0 { + crate::bail!("backward not supported for non integer upscaling factors") + } + let scale_h = target_h / h; + let scale_w = target_w / w; + + if scale_h != scale_w { + crate::bail!("backward not supported for non uniform upscaling factors") + }; + let kernel = + Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?; + let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = conv_sum; + } + Op::SliceScatter0(lhs, rhs, start_rhs) => { + let rhs_sum_grad = grads.or_insert(rhs)?; + let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + + let lhs_sum_grad = grads.or_insert(lhs)?; + let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)? + } + Op::Gather(arg, indexes, dim) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; + } + Op::ScatterAdd(init, indexes, src, dim) => { + let init_sum_grad = grads.or_insert(init)?; + *init_sum_grad = init_sum_grad.add(&grad)?; + + let src_grad = grad.gather(indexes, *dim)?; + let src_sum_grad = grads.or_insert(src)?; + *src_sum_grad = src_sum_grad.add(&src_grad)?; + } + Op::IndexAdd(init, indexes, src, dim) => { + let init_sum_grad = grads.or_insert(init)?; + *init_sum_grad = init_sum_grad.add(&grad)?; + + let src_grad = grad.index_select(indexes, *dim)?; + let src_sum_grad = grads.or_insert(src)?; + *src_sum_grad = src_sum_grad.add(&src_grad)?; + } + Op::IndexSelect(arg, indexes, dim) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?; + } + Op::Matmul(lhs, rhs) => { + // Skipping checks, the op went ok, we can skip + // the matmul size checks for now. + + let lhs_grad = grad.matmul(&rhs.t()?)?; + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + + let rhs_grad = lhs.t()?.matmul(&grad)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + } + Op::Cat(args, dim) => { + let mut start_idx = 0; + for arg in args { + let len = arg.dims()[*dim]; + let arg_grad = grad.narrow(*dim, start_idx, len)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)?; + start_idx += len; + } + } + Op::Broadcast(arg) => { + let arg_dims = arg.dims(); + let node_dims = node.dims(); + // The number of dims that have been inserted on the left. + let left_dims = node_dims.len() - arg_dims.len(); + let mut sum_dims: Vec = (0..left_dims).collect(); + for (dim, (node_dim, arg_dim)) in node_dims[left_dims..] + .iter() + .zip(arg_dims.iter()) + .enumerate() + { + if node_dim != arg_dim { + sum_dims.push(dim + left_dims) + } + } + + let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?; + for _i in 0..left_dims { + arg_grad = arg_grad.squeeze(0)? 
+ } + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?; + } + Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => { + let grad = broadcast_back(arg, &grad, reduced_dims)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad)?; + } + Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { + let node = broadcast_back(arg, node, reduced_dims)?; + let grad = broadcast_back(arg, &grad, reduced_dims)?; + let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?; + } + Op::Reduce(arg, ReduceOp::Min, reduced_dims) => { + let node = broadcast_back(arg, node, reduced_dims)?; + let grad = broadcast_back(arg, &grad, reduced_dims)?; + let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?; + } + Op::ToDType(arg) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)? + } + Op::Copy(arg) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&grad)? + } + Op::Affine { arg, mul, .. } => { + let arg_grad = grad.affine(*mul, 0.)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::Unary(arg, UnaryOp::Log) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&(grad / arg)?)? + } + Op::Unary(arg, UnaryOp::Sin) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&(&grad * arg.cos())?)? + } + Op::Unary(arg, UnaryOp::Cos) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)? + } + Op::Unary(arg, UnaryOp::Tanh) => { + let sum_grad = grads.or_insert(arg)?; + let minus_dtanh = (node.sqr()? - 1.)?; + *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)? + } + Op::Unary(arg, UnaryOp::Abs) => { + let sum_grad = grads.or_insert(arg)?; + let ones = arg.ones_like()?; + let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?); + *sum_grad = sum_grad.add(&(&grad * abs_grad)?)? + } + Op::Unary(arg, UnaryOp::Exp) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&(&grad * *node)?)? + } + Op::Unary(arg, UnaryOp::Neg) => { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.sub(&grad)? + } + Op::Unary(arg, UnaryOp::Recip) => { + let sum_grad = grads.or_insert(arg)?; + let grad = (grad / arg.sqr()?)?; + *sum_grad = sum_grad.sub(&grad)? + } + &Op::Narrow(ref arg, dim, start_idx, len) => { + let arg_dims = arg.dims(); + let left_pad = if start_idx == 0 { + None + } else { + let mut dims = arg_dims.to_vec(); + dims[dim] = start_idx; + Some(Tensor::zeros(dims, grad.dtype(), grad.device())?) + }; + let right_pad = arg_dims[dim] - start_idx - len; + let right_pad = if right_pad == 0 { + None + } else { + let mut dims = arg_dims.to_vec(); + dims[dim] = right_pad; + Some(Tensor::zeros(dims, grad.dtype(), grad.device())?) + }; + let arg_grad = match (left_pad, right_pad) { + (None, None) => grad, + (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?, + (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?, + (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?, + }; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? 
+ } + Op::Unary(_, UnaryOp::Floor) + | Op::Unary(_, UnaryOp::Round) + | Op::Reduce(_, ReduceOp::ArgMin, _) + | Op::Reduce(_, ReduceOp::ArgMax, _) + | Op::Unary(_, UnaryOp::Sign) + | Op::Cmp(_, _) => {} + Op::Reshape(arg) => { + let arg_grad = grad.reshape(arg.dims())?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, + Op::Unary(arg, UnaryOp::Gelu) => { + let sum_grad = grads.or_insert(arg)?; + let cube = arg.powf(3.)?; + let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?; + let gelu_grad = (((0.5 * &tanh)? + + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))? + + 0.5)?; + *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)? + } + Op::Unary(arg, UnaryOp::Erf) => { + let sum_grad = grads.or_insert(arg)?; + // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2) + let erf_grad = + (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?; + *sum_grad = sum_grad.add(&(&grad * erf_grad)?)? + } + Op::Unary(arg, UnaryOp::GeluErf) => { + let sum_grad = grads.or_insert(arg)?; + // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2)) + let neg_half_square = (arg.sqr()?.neg()? / 2.)?; + let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?; + let arg_scaled_sqrt = (arg / 2f64.sqrt())?; + let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?; + let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?; + *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?; + } + Op::Unary(arg, UnaryOp::Relu) => { + let sum_grad = grads.or_insert(arg)?; + let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?; + *sum_grad = sum_grad.add(&(&grad * relu_grad)?)? + } + Op::Unary(arg, UnaryOp::Silu) => { + let sum_grad = grads.or_insert(arg)?; + // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node + let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?; + let silu_grad = &sigmoid_arg * (1. - *node) + *node; + *sum_grad = sum_grad.add(&(&grad * silu_grad)?)? + } + Op::Elu(arg, alpha) => { + // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0 + let sum_grad = grads.or_insert(arg)?; + let zeros = arg.zeros_like()?; + let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?; + let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?; + // node == alpha * (e^x - 1) for x <= 0, reuse it + let negative_exp_mask = (negative_mask * (*node + *alpha))?; + let combined_mask = (positive_mask + negative_exp_mask)?; + *sum_grad = sum_grad.add(&(grad * combined_mask)?)? + } + Op::Powf(arg, e) => { + let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::CustomOp1(arg, c) => { + if let Some(arg_grad) = c.bwd(arg, node, &grad)? { + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + } + Op::CustomOp2(arg1, arg2, c) => { + let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?; + if let Some(arg_grad1) = arg_grad1 { + let sum_grad = grads.or_insert(arg1)?; + *sum_grad = sum_grad.add(&arg_grad1)? + } + if let Some(arg_grad2) = arg_grad2 { + let sum_grad = grads.or_insert(arg2)?; + *sum_grad = sum_grad.add(&arg_grad2)? + } + } + Op::CustomOp3(arg1, arg2, arg3, c) => { + let (arg_grad1, arg_grad2, arg_grad3) = + c.bwd(arg1, arg2, arg3, node, &grad)?; + if let Some(arg_grad1) = arg_grad1 { + let sum_grad = grads.or_insert(arg1)?; + *sum_grad = sum_grad.add(&arg_grad1)? 
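+                    // For CustomOp1/2/3 the user-provided `bwd` returns an `Option` per input;
+                    // only the `Some(..)` gradients are accumulated, so a custom op can mark any
+                    // of its inputs as non-differentiable by returning `None` for it.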
+ } + if let Some(arg_grad2) = arg_grad2 { + let sum_grad = grads.or_insert(arg2)?; + *sum_grad = sum_grad.add(&arg_grad2)? + } + if let Some(arg_grad3) = arg_grad3 { + let sum_grad = grads.or_insert(arg3)?; + *sum_grad = sum_grad.add(&arg_grad3)? + } + } + Op::Unary(arg, UnaryOp::Sqr) => { + let arg_grad = arg.mul(&grad)?.affine(2., 0.)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::Unary(arg, UnaryOp::Sqrt) => { + let arg_grad = grad.div(node)?.affine(0.5, 0.)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::ToDevice(arg) => { + let sum_grad = grads.or_insert(arg)?; + let arg_grad = grad.to_device(sum_grad.device())?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::Transpose(arg, dim1, dim2) => { + let arg_grad = grad.transpose(*dim1, *dim2)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + Op::Permute(arg, dims) => { + let mut inv_dims = vec![0; dims.len()]; + for (i, &dim_idx) in dims.iter().enumerate() { + inv_dims[dim_idx] = i + } + let arg_grad = grad.permute(inv_dims)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)? + } + }; + } + } + Ok(grads) + } +} + +/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation. +#[derive(Debug)] +pub struct GradStore(HashMap); + +impl GradStore { + /// Create a new gradient store + fn new() -> Self { + GradStore(HashMap::new()) + } + + /// Get the gradient tensor corresponding to the given tensor id + pub fn get_id(&self, id: TensorId) -> Option<&Tensor> { + self.0.get(&id) + } + + /// Get the gradient tensor associated with the given tensor + pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> { + self.0.get(&tensor.id()) + } + + /// Remove the gradient tensor associated with the given tensor, returning it if it exists + pub fn remove(&mut self, tensor: &Tensor) -> Option { + self.0.remove(&tensor.id()) + } + + /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed + pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option { + self.0.insert(tensor.id(), grad) + } + + /// Get the gradient tensor associated with the given tensor, or, if it does not exist, + /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it + fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> { + use std::collections::hash_map::Entry; + let grad = match self.0.entry(tensor.id()) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + let grad = tensor.zeros_like()?; + entry.insert(grad) + } + }; + Ok(grad) + } + + /// Get the tensor ids of the stored gradient tensors + pub fn get_ids(&self) -> impl Iterator { + self.0.keys() + } +} diff --git a/mcandle/mcandle-core/src/conv.rs b/mcandle/mcandle-core/src/conv.rs new file mode 100644 index 000000000..7b3922dd7 --- /dev/null +++ b/mcandle/mcandle-core/src/conv.rs @@ -0,0 +1,358 @@ +use crate::{op::BackpropOp, op::Op, Error, Result, Tensor}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConv1D { + pub(crate) b_size: usize, + // Maybe we should have a version without l_in as this bit depends on the input and not only on + // the weights. 
+ pub(crate) l_in: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) k_size: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, + pub(crate) dilation: usize, +} + +impl ParamsConv1D { + pub(crate) fn l_out(&self) -> usize { + (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec { + let l_out = self.l_out(); + vec![self.b_size, self.c_out, l_out] + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConvTranspose1D { + pub(crate) b_size: usize, + pub(crate) l_in: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) k_size: usize, + pub(crate) padding: usize, + pub(crate) output_padding: usize, + pub(crate) stride: usize, + pub(crate) dilation: usize, +} + +impl ParamsConvTranspose1D { + pub(crate) fn l_out(&self) -> usize { + (self.l_in - 1) * self.stride - 2 * self.padding + + self.dilation * (self.k_size - 1) + + self.output_padding + + 1 + } + + pub(crate) fn out_dims(&self) -> Vec { + let l_out = self.l_out(); + vec![self.b_size, self.c_out, l_out] + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum CudnnFwdAlgo { + ImplicitGemm, + ImplicitPrecompGemm, + Gemm, + Direct, + Fft, + FftTiling, + Winograd, + WinogradNonFused, + Count, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConv2D { + pub(crate) b_size: usize, + pub(crate) i_h: usize, + pub(crate) i_w: usize, + pub(crate) k_h: usize, + pub(crate) k_w: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, + pub(crate) dilation: usize, + pub cudnn_fwd_algo: Option, +} + +impl ParamsConv2D { + pub(crate) fn out_h(&self) -> usize { + (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_w(&self) -> usize { + (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec { + vec![self.b_size, self.c_out, self.out_h(), self.out_w()] + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParamsConvTranspose2D { + pub(crate) b_size: usize, + pub(crate) i_h: usize, + pub(crate) i_w: usize, + pub(crate) k_h: usize, + pub(crate) k_w: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) padding: usize, + pub(crate) output_padding: usize, + pub(crate) stride: usize, + pub(crate) dilation: usize, +} + +impl ParamsConvTranspose2D { + pub(crate) fn out_h(&self) -> usize { + (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1 + - 2 * self.padding + } + + pub(crate) fn out_w(&self) -> usize { + (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1 + - 2 * self.padding + } + + pub(crate) fn out_dims(&self) -> Vec { + vec![self.b_size, self.c_out, self.out_h(), self.out_w()] + } +} + +impl Tensor { + fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result { + let storage = + self.storage() + .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 1D convolution over the input tensor. 
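+    ///
+    /// Expects an input of shape `(batch, c_in, l_in)` and a kernel of shape
+    /// `(c_out, c_in / groups, k_size)`; the output length is
+    /// `l_out = (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1`.
+    ///
+    /// A minimal sketch of a call site (shapes and values below are illustrative only):
+    ///
+    /// ```ignore
+    /// let input = Tensor::zeros((1, 4, 16), DType::F32, &Device::Cpu)?; // (batch, c_in, l_in)
+    /// let kernel = Tensor::zeros((8, 4, 3), DType::F32, &Device::Cpu)?; // (c_out, c_in, k_size)
+    /// let out = input.conv1d(&kernel, /* padding */ 1, /* stride */ 1, /* dilation */ 1, /* groups */ 1)?;
+    /// assert_eq!(out.dims(), &[1, 8, 16]);
+    /// ```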
+ pub fn conv1d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + ) -> Result { + let (c_out, c_in_k, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; + if c_in != c_in_k * groups { + Err(Error::Conv1dInvalidArgs { + inp_shape: self.shape().clone(), + k_shape: kernel.shape().clone(), + padding, + stride, + msg: "the number of in-channels on the input doesn't match the kernel size", + } + .bt())? + } + + let params = ParamsConv1D { + b_size, + l_in, + c_out: c_out / groups, + c_in: c_in / groups, + k_size, + padding, + stride, + dilation, + }; + if groups == 1 { + self.conv1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; + let blocks = blocks + .iter() + .zip(&kernel) + .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms)) + .collect::>>()?; + Tensor::cat(&blocks, 1) + } + } + + fn conv_transpose1d_single_group( + &self, + kernel: &Self, + params: &ParamsConvTranspose1D, + ) -> Result { + let storage = self.storage().conv_transpose1d( + self.layout(), + &kernel.storage(), + kernel.layout(), + params, + )?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { + arg, + kernel, + padding: params.padding, + output_padding: params.output_padding, + stride: params.stride, + dilation: params.dilation, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 1D transposed convolution over the input tensor. + pub fn conv_transpose1d( + &self, + kernel: &Self, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, + groups: usize, + ) -> Result { + let (c_in_k, c_out, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; + if c_in != c_in_k { + crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") + } + if c_in % groups != 0 { + crate::bail!("in_channel {c_in} is not divisible by the number of groups") + } + let params = ParamsConvTranspose1D { + b_size, + l_in, + k_size, + c_out, + c_in: c_in / groups, + padding, + output_padding, + stride, + dilation, + }; + if groups == 1 { + self.conv_transpose1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; + let blocks = blocks + .iter() + .zip(&kernel) + .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms)) + .collect::>>()?; + Tensor::cat(&blocks, 1) + } + } + + fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result { + let storage = + self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 2D convolution over the input tensor. 
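+    ///
+    /// Expects an input of shape `(batch, c_in, h, w)` and a kernel of shape
+    /// `(c_out, c_in / groups, k_h, k_w)`. Each spatial output size follows the same formula as
+    /// the 1D case, e.g. `out_h = (h + 2 * padding - dilation * (k_h - 1) - 1) / stride + 1`.
+    /// With `groups > 1` the input channels and the kernel are split into `groups` chunks,
+    /// convolved independently, and concatenated back along the channel dimension, as in the
+    /// body below.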
+ pub fn conv2d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + dilation: usize, + groups: usize, + ) -> Result { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k * groups { + crate::bail!( + "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})" + ) + } + let params = ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out: c_out / groups, + c_in: c_in / groups, + padding, + stride, + dilation, + cudnn_fwd_algo: None, + }; + if groups == 1 { + self.conv2d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; + let blocks = blocks + .iter() + .zip(&kernel) + .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms)) + .collect::>>()?; + Tensor::cat(&blocks, 1) + } + } + + /// Applies a 2D transposed convolution over the input tensor. + pub fn conv_transpose2d( + &self, + kernel: &Self, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, + ) -> Result { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k { + crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") + } + let params = ParamsConvTranspose2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + output_padding, + stride, + dilation, + }; + let storage = self.storage().conv_transpose2d( + self.layout(), + &kernel.storage(), + kernel.layout(), + ¶ms, + )?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D { + arg, + kernel, + padding: params.padding, + output_padding: params.output_padding, + stride: params.stride, + dilation: params.dilation, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } +} diff --git a/mcandle/mcandle-core/src/convert.rs b/mcandle/mcandle-core/src/convert.rs new file mode 100644 index 000000000..173a96d6e --- /dev/null +++ b/mcandle/mcandle-core/src/convert.rs @@ -0,0 +1,161 @@ +//! Implement conversion traits for tensors +use crate::{DType, Device, Error, Tensor, WithDType}; +use float8::F8E4M3; +use half::{bf16, f16, slice::HalfFloatSliceExt}; +use std::convert::TryFrom; + +impl TryFrom<&Tensor> for Vec { + type Error = Error; + fn try_from(tensor: &Tensor) -> Result { + tensor.to_vec1::() + } +} + +impl TryFrom<&Tensor> for Vec> { + type Error = Error; + fn try_from(tensor: &Tensor) -> Result { + tensor.to_vec2::() + } +} + +impl TryFrom<&Tensor> for Vec>> { + type Error = Error; + fn try_from(tensor: &Tensor) -> Result { + tensor.to_vec3::() + } +} + +impl TryFrom for Vec { + type Error = Error; + fn try_from(tensor: Tensor) -> Result { + Vec::::try_from(&tensor) + } +} + +impl TryFrom for Vec> { + type Error = Error; + fn try_from(tensor: Tensor) -> Result { + Vec::>::try_from(&tensor) + } +} + +impl TryFrom for Vec>> { + type Error = Error; + fn try_from(tensor: Tensor) -> Result { + Vec::>>::try_from(&tensor) + } +} + +impl TryFrom<&[T]> for Tensor { + type Error = Error; + fn try_from(v: &[T]) -> Result { + Tensor::from_slice(v, v.len(), &Device::Cpu) + } +} + +impl TryFrom> for Tensor { + type Error = Error; + fn try_from(v: Vec) -> Result { + let len = v.len(); + Tensor::from_vec(v, len, &Device::Cpu) + } +} + +macro_rules! 
from_tensor { + ($typ:ident) => { + impl TryFrom<&Tensor> for $typ { + type Error = Error; + + fn try_from(tensor: &Tensor) -> Result { + tensor.to_scalar::<$typ>() + } + } + + impl TryFrom for $typ { + type Error = Error; + + fn try_from(tensor: Tensor) -> Result { + $typ::try_from(&tensor) + } + } + + impl TryFrom<$typ> for Tensor { + type Error = Error; + + fn try_from(v: $typ) -> Result { + Tensor::new(v, &Device::Cpu) + } + } + }; +} + +from_tensor!(f64); +from_tensor!(f32); +from_tensor!(f16); +from_tensor!(bf16); +from_tensor!(i64); +from_tensor!(u32); +from_tensor!(u8); + +impl Tensor { + pub fn write_bytes(&self, f: &mut W) -> crate::Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + + let vs = self.flatten_all()?; + match self.dtype() { + DType::BF16 => { + let vs = vs.to_vec1::()?; + for &v in vs.reinterpret_cast() { + f.write_u16::(v)? + } + } + DType::F16 => { + let vs = vs.to_vec1::()?; + for &v in vs.reinterpret_cast() { + f.write_u16::(v)? + } + } + DType::F32 => { + // TODO: Avoid using a buffer when data is already on the CPU. + for v in vs.to_vec1::()? { + f.write_f32::(v)? + } + } + DType::F64 => { + for v in vs.to_vec1::()? { + f.write_f64::(v)? + } + } + DType::U32 => { + for v in vs.to_vec1::()? { + f.write_u32::(v)? + } + } + DType::I16 => { + for v in vs.to_vec1::()? { + f.write_i16::(v)? + } + } + DType::I32 => { + for v in vs.to_vec1::()? { + f.write_i32::(v)? + } + } + DType::I64 => { + for v in vs.to_vec1::()? { + f.write_i64::(v)? + } + } + DType::U8 => { + let vs = vs.to_vec1::()?; + f.write_all(&vs)?; + } + DType::F8E4M3 => { + for v in vs.to_vec1::()? { + f.write_u8(v.to_bits())? + } + } + } + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/cpu/avx.rs b/mcandle/mcandle-core/src/cpu/avx.rs new file mode 100644 index 000000000..113fc14ce --- /dev/null +++ b/mcandle/mcandle-core/src/cpu/avx.rs @@ -0,0 +1,227 @@ +use super::{Cpu, CpuBF16, CpuF16}; +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use half::{bf16, f16}; + +pub struct CurrentCpu {} + +const STEP: usize = 32; +const EPR: usize = 8; +const ARR: usize = STEP / EPR; + +impl Cpu for CurrentCpu { + type Unit = __m256; + type Array = [__m256; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + _mm256_setzero_ps() + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + _mm256_set1_ps(v) + } + + unsafe fn load(mem_addr: *const f32) -> Self::Unit { + _mm256_loadu_ps(mem_addr) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + _mm256_add_ps(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + _mm256_add_ps(_mm256_mul_ps(b, c), a) + } + + unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) { + _mm256_storeu_ps(mem_addr, a); + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + for i in 0..ARR / 2 { + x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]); + } + for i in 0..ARR / 4 { + x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]); + } + #[allow(clippy::reversed_empty_ranges)] + for i in 0..ARR / 8 { + x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]); + } + let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1)); + let t1 = _mm_hadd_ps(t0, t0); + *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); + } +} + +pub struct CurrentCpuF16 {} +impl CpuF16 for CurrentCpuF16 { + type Unit = __m256; + 
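+    // All arithmetic in this impl happens on f32 `__m256` lanes: f16 values are widened on
+    // `load` and narrowed on `vec_store`, via the F16C intrinsics when the `f16c` target
+    // feature is enabled and a scalar conversion loop otherwise.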
type Array = [__m256; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + _mm256_setzero_ps() + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + _mm256_set1_ps(v) + } + + #[cfg(target_feature = "f16c")] + unsafe fn load(mem_addr: *const f16) -> Self::Unit { + _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn load(mem_addr: *const f16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + _mm256_loadu_ps(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + _mm256_add_ps(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + _mm256_add_ps(_mm256_mul_ps(b, c), a) + } + + #[cfg(target_feature = "f16c")] + unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) { + _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + _mm256_storeu_ps(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = f16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + let mut offset = ARR >> 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1)); + let t1 = _mm_hadd_ps(t0, t0); + *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); + } +} + +pub struct CurrentCpuBF16 {} +impl CpuBF16 for CurrentCpuBF16 { + type Unit = __m256; + type Array = [__m256; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + _mm256_setzero_ps() + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + _mm256_set1_ps(v) + } + + #[cfg(target_feature = "f16c")] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + _mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn load(mem_addr: *const bf16) -> Self::Unit { + let mut tmp = [0.0f32; 8]; + for i in 0..8 { + tmp[i] = (*mem_addr.add(i)).to_f32(); + } + _mm256_loadu_ps(tmp.as_ptr()) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + _mm256_add_ps(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + _mm256_add_ps(_mm256_mul_ps(b, c), a) + } + + #[cfg(target_feature = "f16c")] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + _mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0)) + } + + #[cfg(not(target_feature = "f16c"))] + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) { + let mut tmp = [0.0f32; 8]; + _mm256_storeu_ps(tmp.as_mut_ptr(), a); + for i in 0..8 { + *mem_addr.add(i) = bf16::from_f32(tmp[i]); + } + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + let mut offset = ARR >> 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = _mm256_add_ps(x[i], x[offset + i]); + } + offset >>= 1; + for i in 0..offset { + x[i] = 
_mm256_add_ps(x[i], x[offset + i]); + } + let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1)); + let t1 = _mm_hadd_ps(t0, t0); + *y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); + } +} diff --git a/mcandle/mcandle-core/src/cpu/erf.rs b/mcandle/mcandle-core/src/cpu/erf.rs new file mode 100644 index 000000000..ca6be53fd --- /dev/null +++ b/mcandle/mcandle-core/src/cpu/erf.rs @@ -0,0 +1,763 @@ +#![allow(clippy::excessive_precision)] +// Code taken from https://github.com/statrs-dev/statrs +//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and +//! related functions + +mod evaluate { + //! Provides functions that don't have a numerical solution and must + //! be solved computationally (e.g. evaluation of a polynomial) + + /// evaluates a polynomial at `z` where `coeff` are the coeffecients + /// to a polynomial of order `k` where `k` is the length of `coeff` and the + /// coeffecient + /// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to + /// `2z^2 - z + 3` + /// + /// # Remarks + /// + /// Returns 0 for a 0 length coefficient slice + pub fn polynomial(z: f64, coeff: &[f64]) -> f64 { + let n = coeff.len(); + if n == 0 { + return 0.0; + } + + let mut sum = *coeff.last().unwrap(); + for c in coeff[0..n - 1].iter().rev() { + sum = *c + z * sum; + } + sum + } +} +use std::f64; + +/// `erf` calculates the error function at `x`. +pub fn erf(x: f64) -> f64 { + if x.is_nan() { + f64::NAN + } else if x >= 0.0 && x.is_infinite() { + 1.0 + } else if x <= 0.0 && x.is_infinite() { + -1.0 + } else if x == 0. { + 0.0 + } else { + erf_impl(x, false) + } +} + +/// `erf_inv` calculates the inverse error function +/// at `x`. +pub fn erf_inv(x: f64) -> f64 { + if x == 0.0 { + 0.0 + } else if x >= 1.0 { + f64::INFINITY + } else if x <= -1.0 { + f64::NEG_INFINITY + } else if x < 0.0 { + erf_inv_impl(-x, 1.0 + x, -1.0) + } else { + erf_inv_impl(x, 1.0 - x, 1.0) + } +} + +/// `erfc` calculates the complementary error function +/// at `x`. +pub fn erfc(x: f64) -> f64 { + if x.is_nan() { + f64::NAN + } else if x == f64::INFINITY { + 0.0 + } else if x == f64::NEG_INFINITY { + 2.0 + } else { + erf_impl(x, true) + } +} + +/// `erfc_inv` calculates the complementary inverse +/// error function at `x`. +pub fn erfc_inv(x: f64) -> f64 { + if x <= 0.0 { + f64::INFINITY + } else if x >= 2.0 { + f64::NEG_INFINITY + } else if x > 1.0 { + erf_inv_impl(-1.0 + x, 2.0 - x, -1.0) + } else { + erf_inv_impl(1.0 - x, x, 1.0) + } +} + +// ********************************************************** +// ********** Coefficients for erf_impl polynomial ********** +// ********************************************************** + +/// Polynomial coefficients for a numerator of `erf_impl` +/// in the interval [1e-10, 0.5]. +const ERF_IMPL_AN: &[f64] = &[ + 0.00337916709551257388990745, + -0.00073695653048167948530905, + -0.374732337392919607868241, + 0.0817442448733587196071743, + -0.0421089319936548595203468, + 0.0070165709512095756344528, + -0.00495091255982435110337458, + 0.000871646599037922480317225, +]; + +/// Polynomial coefficients for a denominator of `erf_impl` +/// in the interval [1e-10, 0.5] +const ERF_IMPL_AD: &[f64] = &[ + 1.0, + -0.218088218087924645390535, + 0.412542972725442099083918, + -0.0841891147873106755410271, + 0.0655338856400241519690695, + -0.0120019604454941768171266, + 0.00408165558926174048329689, + -0.000615900721557769691924509, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [0.5, 0.75]. 
+const ERF_IMPL_BN: &[f64] = &[ + -0.0361790390718262471360258, + 0.292251883444882683221149, + 0.281447041797604512774415, + 0.125610208862766947294894, + 0.0274135028268930549240776, + 0.00250839672168065762786937, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [0.5, 0.75]. +const ERF_IMPL_BD: &[f64] = &[ + 1.0, + 1.8545005897903486499845, + 1.43575803037831418074962, + 0.582827658753036572454135, + 0.124810476932949746447682, + 0.0113724176546353285778481, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [0.75, 1.25]. +const ERF_IMPL_CN: &[f64] = &[ + -0.0397876892611136856954425, + 0.153165212467878293257683, + 0.191260295600936245503129, + 0.10276327061989304213645, + 0.029637090615738836726027, + 0.0046093486780275489468812, + 0.000307607820348680180548455, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [0.75, 1.25]. +const ERF_IMPL_CD: &[f64] = &[ + 1.0, + 1.95520072987627704987886, + 1.64762317199384860109595, + 0.768238607022126250082483, + 0.209793185936509782784315, + 0.0319569316899913392596356, + 0.00213363160895785378615014, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [1.25, 2.25]. +const ERF_IMPL_DN: &[f64] = &[ + -0.0300838560557949717328341, + 0.0538578829844454508530552, + 0.0726211541651914182692959, + 0.0367628469888049348429018, + 0.00964629015572527529605267, + 0.00133453480075291076745275, + 0.778087599782504251917881e-4, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [1.25, 2.25]. +const ERF_IMPL_DD: &[f64] = &[ + 1.0, + 1.75967098147167528287343, + 1.32883571437961120556307, + 0.552528596508757581287907, + 0.133793056941332861912279, + 0.0179509645176280768640766, + 0.00104712440019937356634038, + -0.106640381820357337177643e-7, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [2.25, 3.5]. +const ERF_IMPL_EN: &[f64] = &[ + -0.0117907570137227847827732, + 0.014262132090538809896674, + 0.0202234435902960820020765, + 0.00930668299990432009042239, + 0.00213357802422065994322516, + 0.00025022987386460102395382, + 0.120534912219588189822126e-4, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [2.25, 3.5]. +const ERF_IMPL_ED: &[f64] = &[ + 1.0, + 1.50376225203620482047419, + 0.965397786204462896346934, + 0.339265230476796681555511, + 0.0689740649541569716897427, + 0.00771060262491768307365526, + 0.000371421101531069302990367, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [3.5, 5.25]. +const ERF_IMPL_FN: &[f64] = &[ + -0.00546954795538729307482955, + 0.00404190278731707110245394, + 0.0054963369553161170521356, + 0.00212616472603945399437862, + 0.000394984014495083900689956, + 0.365565477064442377259271e-4, + 0.135485897109932323253786e-5, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [3.5, 5.25]. +const ERF_IMPL_FD: &[f64] = &[ + 1.0, + 1.21019697773630784832251, + 0.620914668221143886601045, + 0.173038430661142762569515, + 0.0276550813773432047594539, + 0.00240625974424309709745382, + 0.891811817251336577241006e-4, + -0.465528836283382684461025e-11, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [5.25, 8]. 
+const ERF_IMPL_GN: &[f64] = &[ + -0.00270722535905778347999196, + 0.0013187563425029400461378, + 0.00119925933261002333923989, + 0.00027849619811344664248235, + 0.267822988218331849989363e-4, + 0.923043672315028197865066e-6, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [5.25, 8]. +const ERF_IMPL_GD: &[f64] = &[ + 1.0, + 0.814632808543141591118279, + 0.268901665856299542168425, + 0.0449877216103041118694989, + 0.00381759663320248459168994, + 0.000131571897888596914350697, + 0.404815359675764138445257e-11, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [8, 11.5]. +const ERF_IMPL_HN: &[f64] = &[ + -0.00109946720691742196814323, + 0.000406425442750422675169153, + 0.000274499489416900707787024, + 0.465293770646659383436343e-4, + 0.320955425395767463401993e-5, + 0.778286018145020892261936e-7, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [8, 11.5]. +const ERF_IMPL_HD: &[f64] = &[ + 1.0, + 0.588173710611846046373373, + 0.139363331289409746077541, + 0.0166329340417083678763028, + 0.00100023921310234908642639, + 0.24254837521587225125068e-4, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [11.5, 17]. +const ERF_IMPL_IN: &[f64] = &[ + -0.00056907993601094962855594, + 0.000169498540373762264416984, + 0.518472354581100890120501e-4, + 0.382819312231928859704678e-5, + 0.824989931281894431781794e-7, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [11.5, 17]. +const ERF_IMPL_ID: &[f64] = &[ + 1.0, + 0.339637250051139347430323, + 0.043472647870310663055044, + 0.00248549335224637114641629, + 0.535633305337152900549536e-4, + -0.117490944405459578783846e-12, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [17, 24]. +const ERF_IMPL_JN: &[f64] = &[ + -0.000241313599483991337479091, + 0.574224975202501512365975e-4, + 0.115998962927383778460557e-4, + 0.581762134402593739370875e-6, + 0.853971555085673614607418e-8, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [17, 24]. +const ERF_IMPL_JD: &[f64] = &[ + 1.0, + 0.233044138299687841018015, + 0.0204186940546440312625597, + 0.000797185647564398289151125, + 0.117019281670172327758019e-4, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [24, 38]. +const ERF_IMPL_KN: &[f64] = &[ + -0.000146674699277760365803642, + 0.162666552112280519955647e-4, + 0.269116248509165239294897e-5, + 0.979584479468091935086972e-7, + 0.101994647625723465722285e-8, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [24, 38]. +const ERF_IMPL_KD: &[f64] = &[ + 1.0, + 0.165907812944847226546036, + 0.0103361716191505884359634, + 0.000286593026373868366935721, + 0.298401570840900340874568e-5, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [38, 60]. +const ERF_IMPL_LN: &[f64] = &[ + -0.583905797629771786720406e-4, + 0.412510325105496173512992e-5, + 0.431790922420250949096906e-6, + 0.993365155590013193345569e-8, + 0.653480510020104699270084e-10, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [38, 60]. +const ERF_IMPL_LD: &[f64] = &[ + 1.0, + 0.105077086072039915406159, + 0.00414278428675475620830226, + 0.726338754644523769144108e-4, + 0.477818471047398785369849e-6, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [60, 85]. 
+const ERF_IMPL_MN: &[f64] = &[ + -0.196457797609229579459841e-4, + 0.157243887666800692441195e-5, + 0.543902511192700878690335e-7, + 0.317472492369117710852685e-9, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [60, 85]. +const ERF_IMPL_MD: &[f64] = &[ + 1.0, + 0.052803989240957632204885, + 0.000926876069151753290378112, + 0.541011723226630257077328e-5, + 0.535093845803642394908747e-15, +]; + +/// Polynomial coefficients for a numerator in `erf_impl` +/// in the interval [85, 110]. +const ERF_IMPL_NN: &[f64] = &[ + -0.789224703978722689089794e-5, + 0.622088451660986955124162e-6, + 0.145728445676882396797184e-7, + 0.603715505542715364529243e-10, +]; + +/// Polynomial coefficients for a denominator in `erf_impl` +/// in the interval [85, 110]. +const ERF_IMPL_ND: &[f64] = &[ + 1.0, + 0.0375328846356293715248719, + 0.000467919535974625308126054, + 0.193847039275845656900547e-5, +]; + +// ********************************************************** +// ********** Coefficients for erf_inv_impl polynomial ****** +// ********************************************************** + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0, 0.5]. +const ERF_INV_IMPL_AN: &[f64] = &[ + -0.000508781949658280665617, + -0.00836874819741736770379, + 0.0334806625409744615033, + -0.0126926147662974029034, + -0.0365637971411762664006, + 0.0219878681111168899165, + 0.00822687874676915743155, + -0.00538772965071242932965, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0, 0.5]. +const ERF_INV_IMPL_AD: &[f64] = &[ + 1.0, + -0.970005043303290640362, + -1.56574558234175846809, + 1.56221558398423026363, + 0.662328840472002992063, + -0.71228902341542847553, + -0.0527396382340099713954, + 0.0795283687341571680018, + -0.00233393759374190016776, + 0.000886216390456424707504, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.5, 0.75]. +const ERF_INV_IMPL_BN: &[f64] = &[ + -0.202433508355938759655, + 0.105264680699391713268, + 8.37050328343119927838, + 17.6447298408374015486, + -18.8510648058714251895, + -44.6382324441786960818, + 17.445385985570866523, + 21.1294655448340526258, + -3.67192254707729348546, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.5, 0.75]. +const ERF_INV_IMPL_BD: &[f64] = &[ + 1.0, + 6.24264124854247537712, + 3.9713437953343869095, + -28.6608180499800029974, + -20.1432634680485188801, + 48.5609213108739935468, + 10.8268667355460159008, + -22.6436933413139721736, + 1.72114765761200282724, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x less than 3. +const ERF_INV_IMPL_CN: &[f64] = &[ + -0.131102781679951906451, + -0.163794047193317060787, + 0.117030156341995252019, + 0.387079738972604337464, + 0.337785538912035898924, + 0.142869534408157156766, + 0.0290157910005329060432, + 0.00214558995388805277169, + -0.679465575181126350155e-6, + 0.285225331782217055858e-7, + -0.681149956853776992068e-9, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x less than 3. 
+const ERF_INV_IMPL_CD: &[f64] = &[ + 1.0, + 3.46625407242567245975, + 5.38168345707006855425, + 4.77846592945843778382, + 2.59301921623620271374, + 0.848854343457902036425, + 0.152264338295331783612, + 0.01105924229346489121, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 3 and 6. +const ERF_INV_IMPL_DN: &[f64] = &[ + -0.0350353787183177984712, + -0.00222426529213447927281, + 0.0185573306514231072324, + 0.00950804701325919603619, + 0.00187123492819559223345, + 0.000157544617424960554631, + 0.460469890584317994083e-5, + -0.230404776911882601748e-9, + 0.266339227425782031962e-11, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 3 and 6. +const ERF_INV_IMPL_DD: &[f64] = &[ + 1.0, + 1.3653349817554063097, + 0.762059164553623404043, + 0.220091105764131249824, + 0.0341589143670947727934, + 0.00263861676657015992959, + 0.764675292302794483503e-4, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 6 and 18. +const ERF_INV_IMPL_EN: &[f64] = &[ + -0.0167431005076633737133, + -0.00112951438745580278863, + 0.00105628862152492910091, + 0.000209386317487588078668, + 0.149624783758342370182e-4, + 0.449696789927706453732e-6, + 0.462596163522878599135e-8, + -0.281128735628831791805e-13, + 0.99055709973310326855e-16, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 6 and 18. +const ERF_INV_IMPL_ED: &[f64] = &[ + 1.0, + 0.591429344886417493481, + 0.138151865749083321638, + 0.0160746087093676504695, + 0.000964011807005165528527, + 0.275335474764726041141e-4, + 0.282243172016108031869e-6, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 18 and 44. +const ERF_INV_IMPL_FN: &[f64] = &[ + -0.0024978212791898131227, + -0.779190719229053954292e-5, + 0.254723037413027451751e-4, + 0.162397777342510920873e-5, + 0.396341011304801168516e-7, + 0.411632831190944208473e-9, + 0.145596286718675035587e-11, + -0.116765012397184275695e-17, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x between 18 and 44. +const ERF_INV_IMPL_FD: &[f64] = &[ + 1.0, + 0.207123112214422517181, + 0.0169410838120975906478, + 0.000690538265622684595676, + 0.145007359818232637924e-4, + 0.144437756628144157666e-6, + 0.509761276599778486139e-9, +]; + +/// Polynomial coefficients for a numerator of `erf_inv_impl` +/// in the interval [0.75, 1] with x greater than 44. +const ERF_INV_IMPL_GN: &[f64] = &[ + -0.000539042911019078575891, + -0.28398759004727721098e-6, + 0.899465114892291446442e-6, + 0.229345859265920864296e-7, + 0.225561444863500149219e-9, + 0.947846627503022684216e-12, + 0.135880130108924861008e-14, + -0.348890393399948882918e-21, +]; + +/// Polynomial coefficients for a denominator of `erf_inv_impl` +/// in the interval [0.75, 1] with x greater than 44. +const ERF_INV_IMPL_GD: &[f64] = &[ + 1.0, + 0.0845746234001899436914, + 0.00282092984726264681981, + 0.468292921940894236786e-4, + 0.399968812193862100054e-6, + 0.161809290887904476097e-8, + 0.231558608310259605225e-11, +]; + +/// `erf_impl` computes the error function at `z`. 
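+/// The approximation is piecewise: on each interval a ratio of two polynomials (the
+/// `ERF_IMPL_*N` / `ERF_IMPL_*D` coefficient tables above) is evaluated with Horner's rule via
+/// `evaluate::polynomial`.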
+/// If `inv` is true, `1 - erf` is calculated as opposed to `erf` +fn erf_impl(z: f64, inv: bool) -> f64 { + if z < 0.0 { + if !inv { + return -erf_impl(-z, false); + } + if z < -0.5 { + return 2.0 - erf_impl(-z, true); + } + return 1.0 + erf_impl(-z, false); + } + + let result = if z < 0.5 { + if z < 1e-10 { + z * 1.125 + z * 0.003379167095512573896158903121545171688 + } else { + z * 1.125 + + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD) + } + } else if z < 110.0 { + let (r, b) = if z < 0.75 { + ( + evaluate::polynomial(z - 0.5, ERF_IMPL_BN) + / evaluate::polynomial(z - 0.5, ERF_IMPL_BD), + 0.3440242112, + ) + } else if z < 1.25 { + ( + evaluate::polynomial(z - 0.75, ERF_IMPL_CN) + / evaluate::polynomial(z - 0.75, ERF_IMPL_CD), + 0.419990927, + ) + } else if z < 2.25 { + ( + evaluate::polynomial(z - 1.25, ERF_IMPL_DN) + / evaluate::polynomial(z - 1.25, ERF_IMPL_DD), + 0.4898625016, + ) + } else if z < 3.5 { + ( + evaluate::polynomial(z - 2.25, ERF_IMPL_EN) + / evaluate::polynomial(z - 2.25, ERF_IMPL_ED), + 0.5317370892, + ) + } else if z < 5.25 { + ( + evaluate::polynomial(z - 3.5, ERF_IMPL_FN) + / evaluate::polynomial(z - 3.5, ERF_IMPL_FD), + 0.5489973426, + ) + } else if z < 8.0 { + ( + evaluate::polynomial(z - 5.25, ERF_IMPL_GN) + / evaluate::polynomial(z - 5.25, ERF_IMPL_GD), + 0.5571740866, + ) + } else if z < 11.5 { + ( + evaluate::polynomial(z - 8.0, ERF_IMPL_HN) + / evaluate::polynomial(z - 8.0, ERF_IMPL_HD), + 0.5609807968, + ) + } else if z < 17.0 { + ( + evaluate::polynomial(z - 11.5, ERF_IMPL_IN) + / evaluate::polynomial(z - 11.5, ERF_IMPL_ID), + 0.5626493692, + ) + } else if z < 24.0 { + ( + evaluate::polynomial(z - 17.0, ERF_IMPL_JN) + / evaluate::polynomial(z - 17.0, ERF_IMPL_JD), + 0.5634598136, + ) + } else if z < 38.0 { + ( + evaluate::polynomial(z - 24.0, ERF_IMPL_KN) + / evaluate::polynomial(z - 24.0, ERF_IMPL_KD), + 0.5638477802, + ) + } else if z < 60.0 { + ( + evaluate::polynomial(z - 38.0, ERF_IMPL_LN) + / evaluate::polynomial(z - 38.0, ERF_IMPL_LD), + 0.5640528202, + ) + } else if z < 85.0 { + ( + evaluate::polynomial(z - 60.0, ERF_IMPL_MN) + / evaluate::polynomial(z - 60.0, ERF_IMPL_MD), + 0.5641309023, + ) + } else { + ( + evaluate::polynomial(z - 85.0, ERF_IMPL_NN) + / evaluate::polynomial(z - 85.0, ERF_IMPL_ND), + 0.5641584396, + ) + }; + let g = (-z * z).exp() / z; + g * b + g * r + } else { + 0.0 + }; + + if inv && z >= 0.5 { + result + } else if z >= 0.5 || inv { + 1.0 - result + } else { + result + } +} + +// `erf_inv_impl` computes the inverse error function where +// `p`,`q`, and `s` are the first, second, and third intermediate +// parameters respectively +fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 { + let result = if p <= 0.5 { + let y = 0.0891314744949340820313; + let g = p * (p + 10.0); + let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD); + g * y + g * r + } else if q >= 0.25 { + let y = 2.249481201171875; + let g = (-2.0 * q.ln()).sqrt(); + let xs = q - 0.25; + let r = + evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD); + g / (y + r) + } else { + let x = (-q.ln()).sqrt(); + if x < 3.0 { + let y = 0.807220458984375; + let xs = x - 1.125; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN) + / evaluate::polynomial(xs, ERF_INV_IMPL_CD); + y * x + r * x + } else if x < 6.0 { + let y = 0.93995571136474609375; + let xs = x - 3.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN) + / evaluate::polynomial(xs, ERF_INV_IMPL_DD); + y * x + r 
* x + } else if x < 18.0 { + let y = 0.98362827301025390625; + let xs = x - 6.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN) + / evaluate::polynomial(xs, ERF_INV_IMPL_ED); + y * x + r * x + } else if x < 44.0 { + let y = 0.99714565277099609375; + let xs = x - 18.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN) + / evaluate::polynomial(xs, ERF_INV_IMPL_FD); + y * x + r * x + } else { + let y = 0.99941349029541015625; + let xs = x - 44.0; + let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN) + / evaluate::polynomial(xs, ERF_INV_IMPL_GD); + y * x + r * x + } + }; + s * result +} diff --git a/mcandle/mcandle-core/src/cpu/kernels.rs b/mcandle/mcandle-core/src/cpu/kernels.rs new file mode 100644 index 000000000..f81ad625d --- /dev/null +++ b/mcandle/mcandle-core/src/cpu/kernels.rs @@ -0,0 +1,220 @@ +pub trait VecOps: num_traits::NumAssign + Copy { + fn min(self, rhs: Self) -> Self; + fn max(self, rhs: Self) -> Self; + + /// Dot-product of two vectors. + /// + /// # Safety + /// + /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid + /// element. + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + *res = Self::zero(); + for i in 0..len { + *res += *lhs.add(i) * *rhs.add(i) + } + } + + /// Sum of all elements in a vector. + /// + /// # Safety + /// + /// The length of `xs` must be at least `len`. `res` has to point to a valid + /// element. + #[inline(always)] + unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) { + *res = Self::zero(); + for i in 0..len { + *res += *xs.add(i) + } + } + + /// Maximum element in a non-empty vector. + /// + /// # Safety + /// + /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid + /// element. + #[inline(always)] + unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) { + *res = *xs; + for i in 1..len { + *res = (*res).max(*xs.add(i)) + } + } + + /// Minimum element in a non-empty vector. + /// + /// # Safety + /// + /// The length of `xs` must be at least `len` and positive. `res` has to point to a valid + /// element. 
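+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (values are illustrative only):
+    ///
+    /// ```ignore
+    /// let xs = [3.0f32, 1.0, 2.0];
+    /// let mut min = 0f32; // overwritten with xs[0] before the scan
+    /// unsafe { <f32 as VecOps>::vec_reduce_min(xs.as_ptr(), &mut min, xs.len()) };
+    /// assert_eq!(min, 1.0);
+    /// ```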
+ #[inline(always)] + unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) { + *res = *xs; + for i in 1..len { + *res = (*res).min(*xs.add(i)) + } + } +} + +impl VecOps for f32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + super::vec_dot_f32(lhs, rhs, res, len) + } + + #[inline(always)] + unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) { + super::vec_sum(xs, res, len) + } +} + +impl VecOps for half::f16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + let mut res_f32 = 0f32; + super::vec_dot_f16(lhs, rhs, &mut res_f32, len); + *res = half::f16::from_f32(res_f32); + } +} + +impl VecOps for f64 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } +} +impl VecOps for half::bf16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + Self::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + Self::max(self, other) + } + + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + let mut res_f32 = 0f32; + super::vec_dot_bf16(lhs, rhs, &mut res_f32, len); + *res = half::bf16::from_f32(res_f32); + } +} +impl VecOps for u8 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for u32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for i16 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for i32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} +impl VecOps for i64 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} + +#[inline(always)] +pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) { + if n_threads == 1 { + func(0) + } else { + rayon::scope(|s| { + for thread_idx in 0..n_threads { + let func = &func; + s.spawn(move |_| func(thread_idx)); + } + }) + } +} + +#[inline(always)] +pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) { + if n_threads == 1 { + for i in lo..up { + func(i) + } + } else { + rayon::scope(|s| { + for thread_idx in 0..n_threads { + let func = &func; + s.spawn(move |_| { + for i in (thread_idx..up).step_by(n_threads) { + func(i) + } + }); + } + }) + } +} diff --git a/mcandle/mcandle-core/src/cpu/mod.rs b/mcandle/mcandle-core/src/cpu/mod.rs new file mode 100644 index 000000000..0b77e6ecb --- /dev/null +++ b/mcandle/mcandle-core/src/cpu/mod.rs @@ -0,0 +1,240 @@ 
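+//! Scalar and SIMD helpers shared by the CPU backend. The `Cpu`, `CpuF16` and `CpuBF16` traits
+//! below abstract over a SIMD register type (AVX `__m256`, NEON `float32x4_t` or wasm `v128`),
+//! and `vec_dot_*` / `vec_sum` take the vectorized path when the matching target feature is
+//! enabled, falling back to a plain scalar loop otherwise.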
+pub mod erf; +pub mod kernels; + +#[allow(unused)] +trait Cpu { + type Unit; + type Array; + const STEP: usize; + const EPR: usize; + + fn n() -> usize; + unsafe fn zero() -> Self::Unit; + unsafe fn zero_array() -> Self::Array; + unsafe fn load(mem_addr: *const f32) -> Self::Unit; + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; + unsafe fn vec_reduce(x: Self::Array, y: *mut f32); + unsafe fn from_f32(v: f32) -> Self::Unit; + unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit); +} + +#[allow(unused)] +trait CpuF16 { + type Unit; + type Array; + const STEP: usize; + const EPR: usize; + + fn n() -> usize; + unsafe fn zero() -> Self::Unit; + unsafe fn zero_array() -> Self::Array; + unsafe fn load(mem_addr: *const f16) -> Self::Unit; + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; + unsafe fn vec_reduce(x: Self::Array, y: *mut f32); + unsafe fn from_f32(v: f32) -> Self::Unit; + unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit); +} + +#[allow(unused)] +trait CpuBF16 { + type Unit; + type Array; + const STEP: usize; + const EPR: usize; + + fn n() -> usize; + unsafe fn zero() -> Self::Unit; + unsafe fn zero_array() -> Self::Array; + unsafe fn load(mem_addr: *const bf16) -> Self::Unit; + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit; + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit; + unsafe fn vec_reduce(x: Self::Array, y: *mut f32); + unsafe fn from_f32(v: f32) -> Self::Unit; + unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit); +} + +use half::{bf16, f16}; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(target_feature = "avx")] +pub mod avx; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(target_feature = "avx")] +pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16}; + +#[cfg(target_arch = "wasm32")] +#[cfg(target_feature = "simd128")] +pub mod simd128; +#[cfg(target_arch = "wasm32")] +#[cfg(target_feature = "simd128")] +pub use simd128::CurrentCpu; + +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +#[cfg(target_feature = "neon")] +pub mod neon; +#[cfg(any(target_arch = "arm", target_arch = "aarch64"))] +#[cfg(target_feature = "neon")] +pub use neon::CurrentCpu; + +#[cfg(any( + target_feature = "neon", + target_feature = "avx", + target_feature = "simd128" +))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) { + let np = k & !(CurrentCpu::STEP - 1); + + let mut sum = CurrentCpu::zero_array(); + let mut ax = CurrentCpu::zero_array(); + let mut ay = CurrentCpu::zero_array(); + + for i in (0..np).step_by(CurrentCpu::STEP) { + for j in 0..CurrentCpu::n() { + ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR)); + ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR)); + + sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]); + } + } + + CurrentCpu::vec_reduce(sum, c); + + // leftovers + for i in np..k { + *c += *a_row.add(i) * (*b_row.add(i)); + } +} + +#[cfg(not(any( + target_feature = "neon", + target_feature = "avx", + target_feature = "simd128" +)))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) { + // leftovers + for i in 0..k { + *c += *a_row.add(i) * (*b_row.add(i)); + } +} + +#[cfg(any( + target_feature = "neon", + target_feature = "avx", + 
target_feature = "simd128" +))] +#[inline(always)] +pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) { + let np = k & !(CurrentCpu::STEP - 1); + + let mut sum = CurrentCpu::zero_array(); + let mut x = CurrentCpu::zero_array(); + + for i in (0..np).step_by(CurrentCpu::STEP) { + for j in 0..CurrentCpu::n() { + x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR)); + sum[j] = CurrentCpu::vec_add(sum[j], x[j]); + } + } + + CurrentCpu::vec_reduce(sum, b); + + // leftovers + for i in np..k { + *b += *row.add(i) + } +} + +#[cfg(not(any( + target_feature = "neon", + target_feature = "avx", + target_feature = "simd128" +)))] +#[inline(always)] +pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) { + *b = 0f32; + for i in 0..k { + *b += *row.add(i) + } +} + +#[cfg(target_feature = "avx")] +#[inline(always)] +pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { + let mut sumf = 0.0f32; + let np = k & !(CurrentCpuF16::STEP - 1); + + let mut sum = CurrentCpuF16::zero_array(); + let mut ax = CurrentCpuF16::zero_array(); + let mut ay = CurrentCpuF16::zero_array(); + + for i in (0..np).step_by(CurrentCpuF16::STEP) { + for j in 0..CurrentCpuF16::n() { + ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR)); + ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR)); + + sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]); + } + } + + CurrentCpuF16::vec_reduce(sum, &mut sumf); + + // leftovers + for i in np..k { + sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sumf; +} + +#[cfg(target_feature = "avx")] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + let mut sumf = 0.0f32; + let np = k & !(CurrentCpuBF16::STEP - 1); + + let mut sum = CurrentCpuBF16::zero_array(); + let mut ax = CurrentCpuBF16::zero_array(); + let mut ay = CurrentCpuBF16::zero_array(); + + for i in (0..np).step_by(CurrentCpuBF16::STEP) { + for j in 0..CurrentCpuBF16::n() { + ax[j] = CurrentCpuBF16::load(a_row.add(i + j * CurrentCpuBF16::EPR)); + ay[j] = CurrentCpuBF16::load(b_row.add(i + j * CurrentCpuBF16::EPR)); + + sum[j] = CurrentCpuBF16::vec_fma(sum[j], ax[j], ay[j]); + } + } + + CurrentCpuBF16::vec_reduce(sum, &mut sumf); + + // leftovers + for i in np..k { + sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sumf; +} + +#[cfg(not(target_feature = "avx"))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) { + // leftovers + let mut sum = 0.0; + for i in 0..k { + sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sum; +} + +#[cfg(not(target_feature = "avx"))] +#[inline(always)] +pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) { + // leftovers + let mut sum = 0.0; + for i in 0..k { + sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32(); + } + *c = sum; +} diff --git a/mcandle/mcandle-core/src/cpu/neon.rs b/mcandle/mcandle-core/src/cpu/neon.rs new file mode 100644 index 000000000..66b8b45e1 --- /dev/null +++ b/mcandle/mcandle-core/src/cpu/neon.rs @@ -0,0 +1,74 @@ +use super::Cpu; +#[cfg(target_arch = "arm")] +use core::arch::arm::*; + +#[cfg(target_arch = "aarch64")] +use core::arch::aarch64::*; + +pub struct CurrentCpu {} + +const STEP: usize = 16; +const EPR: usize = 4; +const ARR: usize = STEP / EPR; + +impl CurrentCpu { + #[cfg(target_arch = "aarch64")] + unsafe fn 
reduce_one(x: float32x4_t) -> f32 { + vaddvq_f32(x) + } + + #[cfg(target_arch = "arm")] + unsafe fn reduce_one(x: float32x4_t) -> f32 { + vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3) + } +} + +impl Cpu for CurrentCpu { + type Unit = float32x4_t; + type Array = [float32x4_t; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + vdupq_n_f32(0.0) + } + + unsafe fn from_f32(x: f32) -> Self::Unit { + vdupq_n_f32(x) + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn load(mem_addr: *const f32) -> Self::Unit { + vld1q_f32(mem_addr) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + vaddq_f32(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + vfmaq_f32(a, b, c) + } + + unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) { + vst1q_f32(mem_addr, a); + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + for i in 0..ARR / 2 { + x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]); + } + for i in 0..ARR / 4 { + x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]); + } + *y = Self::reduce_one(x[0]); + } +} diff --git a/mcandle/mcandle-core/src/cpu/simd128.rs b/mcandle/mcandle-core/src/cpu/simd128.rs new file mode 100644 index 000000000..64b94588b --- /dev/null +++ b/mcandle/mcandle-core/src/cpu/simd128.rs @@ -0,0 +1,64 @@ +use super::Cpu; +use core::arch::wasm32::*; + +pub struct CurrentCpu {} + +const STEP: usize = 16; +const EPR: usize = 4; +const ARR: usize = STEP / EPR; + +impl Cpu for CurrentCpu { + type Unit = v128; + type Array = [v128; ARR]; + + const STEP: usize = STEP; + const EPR: usize = EPR; + + fn n() -> usize { + ARR + } + + unsafe fn zero() -> Self::Unit { + f32x4_splat(0.0) + } + + unsafe fn zero_array() -> Self::Array { + [Self::zero(); ARR] + } + + unsafe fn from_f32(v: f32) -> Self::Unit { + f32x4_splat(v) + } + + unsafe fn load(mem_addr: *const f32) -> Self::Unit { + v128_load(mem_addr as *mut v128) + } + + unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit { + f32x4_add(a, b) + } + + unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit { + f32x4_add(f32x4_mul(b, c), a) + } + + unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) { + v128_store(mem_addr as *mut v128, a); + } + + unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) { + for i in 0..ARR / 2 { + x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]); + } + for i in 0..ARR / 4 { + x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]); + } + for i in 0..ARR / 8 { + x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]); + } + *y = f32x4_extract_lane::<0>(x[0]) + + f32x4_extract_lane::<1>(x[0]) + + f32x4_extract_lane::<2>(x[0]) + + f32x4_extract_lane::<3>(x[0]); + } +} diff --git a/mcandle/mcandle-core/src/cpu_backend/mod.rs b/mcandle/mcandle-core/src/cpu_backend/mod.rs new file mode 100644 index 000000000..6ef74c072 --- /dev/null +++ b/mcandle/mcandle-core/src/cpu_backend/mod.rs @@ -0,0 +1,3759 @@ +use std::ops::Deref; + +use crate::backend::{BackendDevice, BackendStorage}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; +use float8::F8E4M3; +use half::{bf16, f16}; +use rayon::prelude::*; + +mod utils; +pub use utils::{ + binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2Alpha, Map2U8, + Map3, +}; + +const USE_IM2COL_CONV1D: bool = true; +const USE_COL2IM_CONV1D_TR: bool = true; 
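An aside on the `Cpu` trait introduced in mcandle-core/src/cpu/mod.rs above: its associated items encode the SIMD geometry (`Unit` is one register, `Array` a bank of `n()` of them, `EPR` the f32 lanes per register, and `STEP = EPR * n()` the elements consumed per unrolled iteration of `vec_dot_f32`/`vec_sum`). The snippet below is purely illustrative and not part of this patch; it assumes it would sit next to the crate-private trait and only shows how a backend is expected to fill the pieces in.

    // Hypothetical scalar backend: one f32 per "register", a bank of four of them.
    struct ScalarCpu;
    impl Cpu for ScalarCpu {
        type Unit = f32;        // one register holds a single lane
        type Array = [f32; 4];  // a bank of n() = 4 registers
        const STEP: usize = 4;  // EPR * n(): elements consumed per unrolled iteration
        const EPR: usize = 1;   // f32 lanes per register
        fn n() -> usize {
            4
        }
        unsafe fn zero() -> Self::Unit {
            0.0
        }
        unsafe fn zero_array() -> Self::Array {
            [0.0; 4]
        }
        unsafe fn from_f32(v: f32) -> Self::Unit {
            v
        }
        unsafe fn load(mem_addr: *const f32) -> Self::Unit {
            *mem_addr
        }
        unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
            a + b
        }
        unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
            // Matches the SIMD backends: accumulate a + b * c.
            a + b * c
        }
        unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
            *mem_addr = a;
        }
        unsafe fn vec_reduce(x: Self::Array, y: *mut f32) {
            *y = x.iter().sum::<f32>();
        }
    }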
+const USE_IM2COL_CONV2D: bool = true;
+
+// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
+// intercept the oom errors to avoid panicking and provide a proper error.
+#[derive(Debug, Clone)]
+pub enum CpuStorage {
+    U8(Vec<u8>),
+    U32(Vec<u32>),
+    I16(Vec<i16>),
+    I32(Vec<i32>),
+    I64(Vec<i64>),
+    BF16(Vec<bf16>),
+    F16(Vec<f16>),
+    F32(Vec<f32>),
+    F64(Vec<f64>),
+    F8E4M3(Vec<F8E4M3>),
+}
+
+#[derive(Debug, Clone)]
+pub enum CpuStorageRef<'a> {
+    U8(&'a [u8]),
+    U32(&'a [u32]),
+    I16(&'a [i16]),
+    I32(&'a [i32]),
+    I64(&'a [i64]),
+    BF16(&'a [bf16]),
+    F16(&'a [f16]),
+    F32(&'a [f32]),
+    F64(&'a [f64]),
+    F8E4M3(&'a [F8E4M3]),
+}
+
+#[derive(Debug, Clone)]
+pub struct CpuDevice;
+
+struct Cmp(CmpOp);
+impl Map2U8 for Cmp {
+    const OP: &'static str = "cmp";
+    #[inline(always)]
+    fn f<T: WithDType>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<u8>> {
+        let dst = match self.0 {
+            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
+            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
+            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
+            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
+            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
+            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
+        };
+        Ok(dst)
+    }
+}
+
+struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
+
+impl<'a, I: IntDType> Map2 for WCond<'a, I> {
+    const OP: &'static str = "where";
+    #[inline(always)]
+    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
+        let vs = match (
+            self.1.contiguous_offsets(),
+            t_l.contiguous_offsets(),
+            f_l.contiguous_offsets(),
+        ) {
+            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+                let pred = &self.0[o1..o2];
+                let t = &t[o_t1..o_t2];
+                let f = &f[o_f1..o_f2];
+                pred.iter()
+                    .zip(t.iter().zip(f.iter()))
+                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
+                    .collect::<Vec<_>>()
+            }
+            _ => self
+                .1
+                .strided_index()
+                .zip(t_l.strided_index().zip(f_l.strided_index()))
+                .map(|(i_p, (i_t, i_f))| {
+                    if self.0[i_p].is_true() {
+                        t[i_t]
+                    } else {
+                        f[i_f]
+                    }
+                })
+                .collect::<Vec<_>>(),
+        };
+        Ok(vs)
+    }
+}
+
+struct ReduceIndex {
+    reduce_dim_index: usize,
+    use_min: bool,
+    return_index: bool,
+}
+
+impl ReduceIndex {
+    // The value gets replaced if f(s[current_acc], s[i]) returns true.
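+    // For example, with f = |cur, s| cur > s the running value converges to the minimum
+    // (ties keep the first occurrence, since the comparison is strict) and `acc` tracks
+    // its index; `g` then selects either the value or the index for the output.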
+ #[inline(always)] + fn fold_impl(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result> + where + T: Clone + Copy, + U: Clone + Copy, + F: Fn(T, T) -> bool, + G: Fn(T, usize) -> U, + { + let reduce_dim_size = src_l.dims()[self.reduce_dim_index]; + let reduce_dim_stride = src_l.stride()[self.reduce_dim_index]; + let dst_len = src_l.shape().elem_count() / reduce_dim_size; + let mut dst: Vec = Vec::with_capacity(dst_len); + let dst_to_set = dst.spare_capacity_mut(); + let dst_to_set = + unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(dst_to_set) }; + match src_l.contiguous_offsets() { + Some((o1, o2)) => { + let src = &src[o1..o2]; + if reduce_dim_stride == 1 { + for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() { + let start_src_i = start_src_i * reduce_dim_size; + let src = &src[start_src_i..start_src_i + reduce_dim_size]; + let mut acc = 0; + let mut val = src[0]; + for (src_i, &s) in src.iter().enumerate() { + if f(val, s) { + acc = src_i; + val = s + } + } + *dst_v = g(val, acc) + } + } else { + for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() { + let (p, q) = ( + start_src_i / reduce_dim_stride, + start_src_i % reduce_dim_stride, + ); + // start_src_i = p * reduce_dim_stride + q + let start_src_i = p * reduce_dim_stride * reduce_dim_size + q; + let src = &src[start_src_i..]; + let mut acc = 0; + let mut val = src[0]; + for src_i in 0..reduce_dim_size { + let s = src[src_i * reduce_dim_stride]; + if f(val, s) { + acc = src_i; + val = s + } + } + *dst_v = g(val, acc) + } + } + } + None => { + let l = src_l.narrow(self.reduce_dim_index, 0, 1)?; + for (unstr_index, src_index) in l.strided_index().enumerate() { + let src = &src[src_index..]; + let mut acc = 0; + let mut val = src[0]; + for src_i in 0..reduce_dim_size { + let s = src[src_i * reduce_dim_stride]; + if f(val, s) { + acc = src_i; + val = s + } + } + dst_to_set[unstr_index] = g(val, acc) + } + } + } + unsafe { dst.set_len(dst_len) }; + Ok(dst) + } +} + +impl Map1Any for ReduceIndex { + #[inline(always)] + fn f) -> CpuStorage>( + &self, + src: &[T], + src_l: &Layout, + wrap: W, + ) -> Result { + if src_l.shape().elem_count() == 0 { + Err(Error::EmptyTensor { op: "reduce" }.bt())? + } + let dst = match (self.return_index, self.use_min) { + (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?), + (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?), + (true, true) => { + CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?) + } + (true, false) => { + CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?) + } + }; + Ok(dst) + } +} + +struct ReduceSum<'a> { + dst_shape: &'a Shape, + reduce_dims: &'a [usize], + reduce_dims_and_stride: Vec<(usize, usize)>, +} + +impl<'a> ReduceSum<'a> { + #[inline(always)] + fn fold_impl(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result> + where + T: WithDType, + { + let mut dst = vec![start_elt; self.dst_shape.elem_count()]; + match src_l.contiguous_offsets() { + Some((o1, o2)) => { + let src = &src[o1..o2]; + // Handle the case where we reduce over the last dimensions separately as it is + // fairly common and easy to optimize. This rely on the layout being contiguous! + // reduce_dims is sorted, check if it is ranging from a to n-1. 
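+                // E.g. for a contiguous (2, 3, 4) source with reduce_dims == [1, 2] the
+                // check below passes, reduce_sz is 3 * 4 = 12, and each of the two
+                // destination values is a plain sum over one contiguous run of 12 elements.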
+ let reduce_over_last_dims = self + .reduce_dims + .iter() + .rev() + .enumerate() + .all(|(i, &v)| v == src_l.shape().rank() - 1 - i); + if reduce_over_last_dims { + let reduce_sz = self + .reduce_dims_and_stride + .iter() + .map(|(u, _)| u) + .product::(); + for (dst_i, dst_v) in dst.iter_mut().enumerate() { + let src_i = dst_i * reduce_sz; + unsafe { + T::vec_reduce_sum( + src[src_i..src_i + reduce_sz].as_ptr(), + dst_v, + reduce_sz, + ) + }; + } + return Ok(dst); + }; + for (unstr_index, &src) in src.iter().enumerate() { + let mut dst_index = unstr_index; + // Set the reduce_dims indexes to 0. + for &(dim, stride) in self.reduce_dims_and_stride.iter() { + // The compiler is able to optimize the following in a single divmod op. + let (pre, post) = (dst_index / stride, dst_index % stride); + dst_index = (pre / dim) * stride + post; + } + dst[dst_index] += src; + } + } + None => { + for (unstr_index, src_index) in src_l.strided_index().enumerate() { + let mut dst_index = unstr_index; + // Set the reduce_dims indexes to 0. + for &(dim, stride) in self.reduce_dims_and_stride.iter() { + // The compiler is able to optimize the following in a single divmod op. + let (pre, post) = (dst_index / stride, dst_index % stride); + dst_index = (pre / dim) * stride + post; + } + dst[dst_index] += src[src_index]; + } + } + } + Ok(dst) + } +} + +impl<'a> Map1 for ReduceSum<'a> { + #[inline(always)] + fn f(&self, src: &[T], src_l: &Layout) -> Result> { + self.fold_impl(src, src_l, T::zero()) + } +} + +struct Affine(f64, f64); + +impl Map1 for Affine { + fn f(&self, vs: &[T], layout: &Layout) -> Result> { + let mul = T::from_f64(self.0); + let add = T::from_f64(self.1); + Ok(unary_map(vs, layout, |v| v * mul + add)) + } +} + +struct AvgPool2D((usize, usize), (usize, usize)); + +impl Map1 for AvgPool2D { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html + let (k_h, k_w) = self.0; + let (s_h, s_w) = self.1; + let (b_sz, c, h, w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let h_out = (h - k_h) / s_h + 1; + let w_out = (w - k_w) / s_w + 1; + let src_index = layout.start_offset(); + let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + let scale = 1f64 / (k_h * k_w) as f64; + let scale = T::from_f64(scale); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * h_out * w_out..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * h_out * w_out..]; + let src_index = src_index + c_idx * stride[1]; + for h_idx in 0..h_out { + for w_idx in 0..w_out { + let mut sum = T::zero(); + for m in 0..k_h { + for n in 0..k_w { + let m = s_h * h_idx + m; + let n = s_w * w_idx + n; + sum += src[src_index + m * stride_h + n * stride_w] + } + } + dst[h_idx * w_out + w_idx] = sum * scale; + } + } + } + } + Ok(dst) + } +} + +struct MaxPool2D((usize, usize), (usize, usize)); + +impl Map1 for MaxPool2D { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html + let (k_h, k_w) = self.0; + let (s_h, s_w) = self.1; + let (b_sz, c, h, w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let h_out = (h - k_h) / s_h + 1; + let w_out = (w - k_w) / s_w + 1; + let src_index = layout.start_offset(); + let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c 
* h_out * w_out..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * h_out * w_out..]; + let src_index = src_index + c_idx * stride[1]; + for h_idx in 0..h_out { + for w_idx in 0..w_out { + let mut largest = + src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w]; + for m in 0..k_h { + for n in 0..k_w { + let m = s_h * h_idx + m; + let n = s_w * w_idx + n; + if largest < src[src_index + m * stride_h + n * stride_w] { + largest = src[src_index + m * stride_h + n * stride_w] + } + } + } + dst[h_idx * w_out + w_idx] = largest; + } + } + } + } + Ok(dst) + } +} + +struct UpsampleNearest1D(usize); + +impl Map1 for UpsampleNearest1D { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + // TODO: Specialized implementation for the case 2*sz? + let dst_sz = self.0; + let (b_sz, c, src_sz) = layout.shape().dims3()?; + let stride = layout.stride(); + let stride_sz = stride[2]; + let src_index = layout.start_offset(); + let scale_sz = src_sz as f64 / dst_sz as f64; + let mut dst = vec![T::zero(); b_sz * c * dst_sz]; + let src_idxs = (0..dst_sz) + .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize)) + .collect::>(); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * dst_sz..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * dst_sz..]; + let src_index = src_index + c_idx * stride[1]; + for (idx, src_idx) in src_idxs.iter().enumerate() { + dst[idx] = src[src_index + src_idx * stride_sz] + } + } + } + Ok(dst) + } +} + +struct UpsampleNearest2D(usize, usize); + +impl Map1 for UpsampleNearest2D { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + // TODO: Specialized implementation for the case 2*h, 2*w? + let (dst_h, dst_w) = (self.0, self.1); + let (b_sz, c, src_h, src_w) = layout.shape().dims4()?; + let stride = layout.stride(); + let (stride_h, stride_w) = (stride[2], stride[3]); + let src_index = layout.start_offset(); + let scale_h = src_h as f64 / dst_h as f64; + let scale_w = src_w as f64 / dst_w as f64; + let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w]; + let src_h_idxs = (0..dst_h) + .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize)) + .collect::>(); + let src_w_idxs = (0..dst_w) + .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize)) + .collect::>(); + for b_idx in 0..b_sz { + let dst = &mut dst[b_idx * c * dst_h * dst_w..]; + let src_index = src_index + b_idx * stride[0]; + for c_idx in 0..c { + let dst = &mut dst[c_idx * dst_h * dst_w..]; + let src_index = src_index + c_idx * stride[1]; + for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() { + for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() { + let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w; + dst[h_idx * dst_w + w_idx] = src[src_index] + } + } + } + } + Ok(dst) + } +} + +struct Gather<'a, I: IntDType> { + ids: &'a [I], + ids_l: &'a Layout, + dim: usize, +} + +impl<'a, I: IntDType> Map1 for Gather<'a, I> { + fn f(&self, src: &[T], src_l: &Layout) -> Result> { + let ids = match self.ids_l.contiguous_offsets() { + Some((a, b)) => &self.ids[a..b], + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let src = match src_l.contiguous_offsets() { + Some((a, b)) => &src[a..b], + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let dim = self.dim; + let ids_dims = self.ids_l.dims(); + let src_dims = src_l.dims(); + let dst_len: usize = ids_dims.iter().product(); + let dst_left_len: usize 
= ids_dims[..dim].iter().product(); + let dst_dim_len = ids_dims[dim]; + let dst_right_len: usize = ids_dims[dim + 1..].iter().product(); + + let src_dim_len = src_dims[dim]; + let src_right_len: usize = src_dims[dim + 1..].iter().product(); + + let mut dst = vec![T::zero(); dst_len]; + for left_i in 0..dst_left_len { + let start_src_idx = left_i * src_right_len * src_dim_len; + let start_dst_idx = left_i * dst_right_len * dst_dim_len; + for i in 0..dst_dim_len { + let start_dst_idx = start_dst_idx + i * dst_right_len; + for right_i in 0..dst_right_len { + let dst_idx = start_dst_idx + right_i; + let index = ids[dst_idx].as_usize(); + if index >= src_dim_len { + Err(Error::InvalidIndex { + index, + size: src_dim_len, + op: "gather", + } + .bt())? + } + let src_idx = start_src_idx + index * src_right_len + right_i; + dst[dst_idx] = src[src_idx] + } + } + } + Ok(dst) + } +} + +struct IndexSelect<'a, T: IntDType> { + ids: &'a [T], + ids_l: &'a Layout, + dim: usize, +} + +impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> { + fn f(&self, src: &[T], layout: &Layout) -> Result> { + let src = match layout.contiguous_offsets() { + Some((a, b)) => &src[a..b], + None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?, + }; + let dim = self.dim; + let n_ids = match self.ids_l.dims() { + [n_ids] => *n_ids, + d => Err(Error::UnexpectedNumberOfDims { + expected: 1, + got: d.len(), + shape: self.ids_l.shape().clone(), + } + .bt())?, + }; + let stride_ids = self.ids_l.stride()[0]; + let mut dst_dims = layout.dims().to_vec(); + let src_dim = dst_dims[dim]; + dst_dims[dim] = n_ids; + let dst_len: usize = dst_dims.iter().product(); + let left_len: usize = dst_dims[..dim].iter().product(); + let right_len: usize = dst_dims[dim + 1..].iter().product(); + let mut dst = vec![T::zero(); dst_len]; + for left_i in 0..left_len { + let start_src_idx = left_i * right_len * src_dim; + let start_dst_idx = left_i * right_len * n_ids; + for i in 0..n_ids { + let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize(); + if index >= src_dim { + Err(Error::InvalidIndex { + index, + size: src_dim, + op: "index-select", + } + .bt())? 
+ } + let start_src_idx = start_src_idx + index * right_len; + let start_dst_idx = start_dst_idx + i * right_len; + dst[start_dst_idx..start_dst_idx + right_len] + .copy_from_slice(&src[start_src_idx..start_src_idx + right_len]) + } + } + Ok(dst) + } +} + +struct ScatterAdd<'a, I: IntDType> { + ids: &'a [I], + ids_l: &'a Layout, + dim: usize, +} + +impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { + const OP: &'static str = "scatter-add"; + fn f(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result> { + let dst_len = l1.shape().elem_count(); + let mut dst = vec![T::zero(); dst_len]; + copy_strided_src_(v1, &mut dst, 0, l1); + let src = match src_l.contiguous_offsets() { + None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?, + Some((o1, o2)) => &src[o1..o2], + }; + + let dim = self.dim; + let ids_dims = self.ids_l.dims(); + let dst_dims = l1.dims(); + let dst_dim_len = dst_dims[dim]; + let dst_right_len: usize = dst_dims[dim + 1..].iter().product(); + + let ids_left_len: usize = ids_dims[..dim].iter().product(); + let ids_dim_len = ids_dims[dim]; + let ids_right_len: usize = ids_dims[dim + 1..].iter().product(); + + let ids = match self.ids_l.contiguous_offsets() { + Some((a, b)) => &self.ids[a..b], + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, + }; + for left_i in 0..ids_left_len { + let start_ids_idx = left_i * ids_right_len * ids_dim_len; + let start_dst_idx = left_i * dst_right_len * dst_dim_len; + for i in 0..ids_dim_len { + let start_ids_idx = start_ids_idx + i * ids_right_len; + for right_i in 0..dst_right_len { + let ids_idx = start_ids_idx + right_i; + let index = ids[ids_idx].as_usize(); + if index >= dst_dim_len { + Err(Error::InvalidIndex { + index, + size: dst_dim_len, + op: "gather", + } + .bt())? + } + let dst_idx = start_dst_idx + index * dst_right_len + right_i; + dst[dst_idx] += src[ids_idx] + } + } + } + + Ok(dst) + } +} + +struct IndexAdd<'a, I: IntDType> { + ids: &'a [I], + dim: usize, +} + +impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { + const OP: &'static str = "index-add"; + // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_ + // v1, l1 -> self + fn f(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result> { + let dst_len = l1.shape().elem_count(); + let mut dst = vec![T::zero(); dst_len]; + copy_strided_src_(v1, &mut dst, 0, l1); + let src = match src_l.contiguous_offsets() { + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + Some((o1, o2)) => &src[o1..o2], + }; + let dim = self.dim; + let max_idx = l1.dims()[dim]; + let pre_dim = src_l.dims()[..dim].iter().product::(); + let src_dim_sz = src_l.dims()[dim]; + let post_dim = src_l.dims()[dim + 1..].iter().product::(); + if dim == 0 { + for (src_idx, dst_idx) in self.ids.iter().enumerate() { + let dst_idx = dst_idx.as_usize(); + if dst_idx >= max_idx { + Err(Error::InvalidIndex { + index: dst_idx, + op: "index-add", + size: max_idx, + })? + } + let src_idx = src_idx * post_dim; + let dst_idx = dst_idx * post_dim; + let src = &src[src_idx..src_idx + post_dim]; + let dst = &mut dst[dst_idx..dst_idx + post_dim]; + for (d, &s) in dst.iter_mut().zip(src.iter()) { + *d += s + } + } + } else { + for (src_idx, dst_idx) in self.ids.iter().enumerate() { + let dst_idx = dst_idx.as_usize(); + if dst_idx >= max_idx { + Err(Error::InvalidIndex { + index: dst_idx, + op: "index-add", + size: max_idx, + })? 
+ } + for pre_i in 0..pre_dim { + let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim; + let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim; + let src = &src[pre_src_i..pre_src_i + post_dim]; + let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim]; + for (d, &s) in dst.iter_mut().zip(src.iter()) { + *d += s + } + } + } + } + Ok(dst) + } +} + +#[allow(clippy::too_many_arguments)] +fn copy2d_( + src: &[T], + dst: &mut [T], + d1: usize, + d2: usize, + src_stride1: usize, + dst_stride1: usize, + src_offset: usize, + dst_offset: usize, +) { + for i1 in 0..d1 { + let dst_idx = i1 * dst_stride1 + dst_offset; + let src_idx = i1 * src_stride1 + src_offset; + let dst = &mut dst[dst_idx..dst_idx + d2]; + let src = &src[src_idx..src_idx + d2]; + dst.copy_from_slice(src) + } +} + +fn copy_strided_src_(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) { + match src_l.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + let to_copy = (dst.len() - dst_offset).min(len); + dst[dst_offset..dst_offset + to_copy] + .copy_from_slice(&src[start_offset..start_offset + to_copy]) + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len: 1, + } => { + for (dst_index, src_index) in block_start_index.enumerate() { + let dst_index = dst_index + dst_offset; + if dst_index >= dst.len() { + break; + } + dst[dst_index] = src[src_index] + } + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let mut dst_index = dst_offset; + for src_index in block_start_index { + let next_dst_index = dst_index + block_len; + if dst_index >= dst.len() { + break; + } + let to_copy = usize::min(block_len, dst.len() - dst_index); + dst[dst_index..dst_index + to_copy] + .copy_from_slice(&src[src_index..src_index + to_copy]); + dst_index = next_dst_index + } + } + } +} + +struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); + +impl<'a> Map2 for Conv1D<'a> { + const OP: &'static str = "conv1d"; + fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let k = &k[k_l.start_offset()..]; + let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; + let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; + let l_out = p.l_out(); + let dst_elems = p.c_out * l_out * p.b_size; + // The output shape is [b_size, c_out, l_out] + let dst = vec![T::zero(); dst_elems]; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. 
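+        // The packing below rewrites the (possibly strided) [b, c_in, l_in] input into
+        // [b, l_in, c_in] order, so that for every output position the c_in values fed
+        // to T::vec_dot further down form one contiguous slice.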
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in]; + for b_idx in 0..p.b_size { + for src_l in 0..p.l_in { + for src_c_idx in 0..p.c_in { + let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2; + inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx] + } + } + } + + for offset in 0..p.k_size { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { + let dst_idx = dst_c_idx * l_out; + let k_cont = (0..p.c_in) + .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2]) + .collect::>(); + for b_idx in 0..p.b_size { + let dst_idx = dst_idx + b_idx * p.c_out * l_out; + for dst_l in 0..l_out { + let dst_idx = dst_idx + dst_l; + let src_l = p.stride * dst_l + offset * p.dilation; + if src_l < p.padding || src_l >= p.padding + p.l_in { + continue; + } + let src_l = src_l - p.padding; + let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..]; + assert!(inp_cont.len() >= p.c_in); + assert!(k_cont.len() >= p.c_in); + let mut d = T::zero(); + unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) } + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise + // the different tasks so no two threads can try to write at the same + // location. + unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d + } + } + } + }) + } + Ok(dst) + } +} + +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f(&self, vs: &[T], layout: &Layout) -> Result> { + let &Self { + l_k, + stride, + dilation, + padding, + } = self; + let (b, c, l) = layout.shape().dims3()?; + let l_out = self.l_out(l); + let src = &vs[layout.start_offset()..]; + let mut dst = vec![T::zero(); b * l_out * c * l_k]; + let (src_s0, src_s1, src_s2) = { + let s = layout.stride(); + (s[0], s[1], s[2]) + }; + // TODO: provide specialized kernels for the common use cases. 
+ // - l_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * l_out * c * l_k; + for l_idx in 0..l_out { + let dst_idx = dst_idx + l_idx * c * l_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * l_k; + let src_idx = c_idx * src_s1 + src_idx; + for l_k_idx in 0..l_k { + let src_l = l_idx * stride + l_k_idx * dilation; + if padding != 0 && (src_l < padding || src_l >= l + padding) { + continue; + } + let src_l = src_l - padding; + let src_idx = src_idx + src_l * src_s2; + let dst_idx = dst_idx + l_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + Ok(dst) + } +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f(&self, vs: &[T], layout: &Layout) -> Result> { + let &Self { + h_k, + w_k, + stride, + dilation, + padding, + } = self; + let (b, c, h, w) = layout.shape().dims4()?; + let (h_out, w_out) = self.hw_out(h, w); + let src = &vs[layout.start_offset()..]; + let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k]; + let (src_s0, src_s1, src_s2, src_s3) = { + let s = layout.stride(); + (s[0], s[1], s[2], s[3]) + }; + // TODO: provide specialized kernels for the common use cases. + // - h_k = w_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * h_out * w_out * c * h_k * w_k; + for h_idx in 0..h_out { + let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k; + for w_idx in 0..w_out { + let dst_idx = dst_idx + w_idx * c * h_k * w_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * h_k * w_k; + let src_idx = c_idx * src_s1 + src_idx; + for h_k_idx in 0..h_k { + let src_h = h_idx * stride + h_k_idx * dilation; + if padding != 0 && (src_h < padding || src_h >= h + padding) { + continue; + } + let src_h = src_h - padding; + let src_idx = src_idx + src_h * src_s2; + let dst_idx = dst_idx + h_k_idx * w_k; + for w_k_idx in 0..w_k { + let src_w = w_idx * stride + w_k_idx * dilation; + if padding != 0 && (src_w < padding || src_w >= w + padding) { + continue; + } + let src_w = src_w - padding; + let src_idx = src_idx + src_w * src_s3; + let dst_idx = dst_idx + w_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + } + } + Ok(dst) + } +} + +struct Col2Im1D { + stride: usize, +} + +impl Map1 for Col2Im1D { + fn f(&self, col: &[T], l: &Layout) -> Result> { + let (b_size, l_in, c_out, k_size) = l.shape().dims4()?; + let stride = self.stride; + let l_out = (l_in - 1) * stride + k_size; + let mut im = vec![T::zero(); b_size * c_out * l_out]; + let (dst_s0, dst_s1) = (c_out * l_out, l_out); + let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size); + for l_in_i in 0..l_in { + for k_i in 0..k_size { + let l_out_i = l_in_i * stride + k_i; + for b_i in 0..b_size { + for c_i in 0..c_out { + let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i; + let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i; + im[dst_idx] += col[src_idx] + } + } + } + } + Ok(im) + } +} + +struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); + +impl<'a> Map2 for ConvTranspose1D<'a> { + const OP: &'static str = 
"conv_transpose1d"; + fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let k = &k[k_l.start_offset()..]; + let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; + let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; + let l_out = p.l_out(); + + // Output shape: [b_size, c_out, l_out]. + let dst_elems = p.c_out * l_out * p.b_size; + let dst = vec![T::zero(); dst_elems]; + let dst_s0 = p.c_out * l_out; + let dst_s1 = l_out; + let dst_s2 = 1; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. + let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in]; + let cont_s0 = p.l_in * p.c_in; + let cont_s1 = p.c_in; + for b_idx in 0..p.b_size { + for l_idx in 0..p.l_in { + for c_idx in 0..p.c_in { + let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2; + let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx; + inp_cont[dst_idx] = inp[src_idx] + } + } + } + + for k_idx in 0..p.k_size { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { + let k_cont = (0..p.c_in) + .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2]) + .collect::>(); + for b_idx in 0..p.b_size { + for l_idx in 0..p.l_in { + let out_idx = l_idx * p.stride + k_idx * p.dilation; + if out_idx < p.padding { + continue; + } + let out_idx = out_idx - p.padding; + if out_idx < l_out { + let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..]; + let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1; + let mut d = T::zero(); + unsafe { + T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) + } + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to + // parallelise the different tasks so no two threads can try to + // write at the same location. + unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d + } + } + } + } + }) + } + Ok(dst) + } +} + +struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); + +impl<'a> Map2 for Conv2D<'a> { + const OP: &'static str = "conv2d"; + fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; + let k = &k[k_l.start_offset()..]; + let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; + let (out_h, out_w) = (p.out_h(), p.out_w()); + + // Output shape: [b_size, c_out, out_h, out_w]. + let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. 
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; + let cont_s0 = p.i_h * p.i_w * p.c_in; + let cont_s1 = p.i_w * p.c_in; + let cont_s2 = p.c_in; + for b_idx in 0..p.b_size { + for h_idx in 0..p.i_h { + for w_idx in 0..p.i_w { + for c_idx in 0..p.c_in { + let src_idx = + b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3; + let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx; + inp_cont[dst_idx] = inp[src_idx] + } + } + } + } + + for offset_h in 0..p.k_h { + for offset_w in 0..p.k_w { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { + let dst_idx = dst_c_idx * out_w * out_h; + let k_cont = (0..p.c_in) + .map(|c_in_idx| { + k[dst_c_idx * k_s0 + + c_in_idx * k_s1 + + offset_h * k_s2 + + offset_w * k_s3] + }) + .collect::>(); + for b_idx in 0..p.b_size { + let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w; + for dst_h in 0..out_h { + let dst_idx = dst_idx + dst_h * out_w; + let src_h = p.stride * dst_h + offset_h * p.dilation; + if src_h < p.padding || src_h >= p.i_h + p.padding { + continue; + } + let src_h = src_h - p.padding; + for dst_w in 0..out_w { + let dst_idx = dst_idx + dst_w; + let src_w = p.stride * dst_w + offset_w * p.dilation; + if src_w < p.padding || src_w >= p.i_w + p.padding { + continue; + } + let src_w = src_w - p.padding; + let inp_cont = &inp_cont + [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..]; + assert!(inp_cont.len() >= p.c_in); + assert!(k_cont.len() >= p.c_in); + let mut d = T::zero(); + unsafe { + T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) + } + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise + // the different tasks so no two threads can try to write at the same + // location. + unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d + } + } + } + } + }); + } + } + + Ok(dst) + } +} + +struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); + +impl<'a> Map2 for ConvTranspose2D<'a> { + const OP: &'static str = "conv_transpose2d"; + fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; + let k = &k[k_l.start_offset()..]; + let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; + let (out_h, out_w) = (p.out_h(), p.out_w()); + + // Output shape: [b_size, c_out, out_h, out_w]. + let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + let dst_s0 = p.c_out * out_h * out_w; + let dst_s1 = out_h * out_w; + let dst_s2 = out_w; + let dst_s3 = 1; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. 
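+        // Transposed convolution scatters rather than gathers: each input pixel
+        // (inp_y, inp_x) and kernel tap (k_y, k_x) contributes to output position
+        // (inp_y * stride + k_y * dilation - padding, inp_x * stride + k_x * dilation - padding),
+        // accumulated with +=, hence the bounds checks against out_h / out_w below.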
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; + let cont_s0 = p.i_h * p.i_w * p.c_in; + let cont_s1 = p.i_w * p.c_in; + let cont_s2 = p.c_in; + for b_idx in 0..p.b_size { + for h_idx in 0..p.i_h { + for w_idx in 0..p.i_w { + for c_idx in 0..p.c_in { + let src_idx = + b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3; + let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx; + inp_cont[dst_idx] = inp[src_idx] + } + } + } + } + + for k_y in 0..p.k_h { + for k_x in 0..p.k_w { + (0..p.c_out).into_par_iter().for_each(|dst_c_idx| { + let k_cont = (0..p.c_in) + .map(|c_in_idx| { + k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3] + }) + .collect::>(); + for b_idx in 0..p.b_size { + for inp_y in 0..p.i_h { + for inp_x in 0..p.i_w { + let out_x = inp_x * p.stride + k_x * p.dilation; + let out_y = inp_y * p.stride + k_y * p.dilation; + if out_x < p.padding || out_y < p.padding { + continue; + } + let out_x = out_x - p.padding; + let out_y = out_y - p.padding; + if out_x < out_w && out_y < out_h { + let inp_cont = &inp_cont + [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..]; + let dst_idx = b_idx * dst_s0 + + out_y * dst_s2 + + out_x * dst_s3 + + dst_c_idx * dst_s1; + let mut d = T::zero(); + unsafe { + T::vec_dot( + inp_cont.as_ptr(), + k_cont.as_ptr(), + &mut d, + p.c_in, + ) + } + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to + // parallelise the different tasks so no two threads can try to + // write at the same location. + unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d + } + } + } + } + } + }) + } + } + Ok(dst) + } +} + +struct MatMul((usize, usize, usize, usize)); + +impl MatMul { + fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error { + Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding { + lhs_l: lhs_l.clone(), + rhs_l: rhs_l.clone(), + bmnk: self.0, + msg, + })) + .bt() + } + + fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> { + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let (_b, m, n, k) = self.0; + let a_skip: usize = match lhs_stride[..rank - 2] { + [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, + [_, stride] if lhs_l.dims()[0] == 1 => stride, + [stride, _] if lhs_l.dims()[1] == 1 => stride, + [stride] => stride, + [] => m * k, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, + }; + let b_skip: usize = match rhs_stride[..rank - 2] { + [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, + [_, stride] if rhs_l.dims()[0] == 1 => stride, + [stride, _] if rhs_l.dims()[1] == 1 => stride, + [stride] => stride, + [] => n * k, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, + }; + Ok((a_skip, b_skip)) + } +} + +impl Map2 for MatMul { + const OP: &'static str = "mat_mul"; + + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + ) -> Result> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, + } + + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + 
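+        // (gemm takes raw element strides: `*_cs` is the stride between columns and
+        // `*_rs` the stride between rows, so a row-major contiguous (.., m, k) lhs has
+        // lhs_cs == 1 and lhs_rs == k.)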
let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let mut dst = vec![T::zero(); b * m * n]; + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ false, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::zero(), + /* beta: T = */ T::one(), + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + Ok(dst) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + ) -> Result> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? 
+ }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + ) -> Result> { + let (b, m, n, k) = self.0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? 
+ }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::ONE, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ZERO, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } +} + +struct MatMulWithBias(MatMul); + +impl Deref for MatMulWithBias { + type Target = MatMul; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Map3 for MatMulWithBias { + const OP: &'static str = "mat_mul_ac"; + + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, + } + + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + 
Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + let alpha = T::from_f64(s.unwrap_or(1.0)); + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ true, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::one(), + /* beta: T = */ alpha, + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + Ok(()) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) as f64, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(()) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + c: &mut [T], + c_l: &Layout, + s: Option, + ) -> Result<()> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? 
+ }; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != b * m * n { + crate::bail!("`c` end offset must be {}", b * m * n) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::from_f64(s.unwrap_or(1.)), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ONE, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut c[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(()) + } +} + +struct MatMulWithAlpha(MatMul); + +impl Deref for MatMulWithAlpha { + type Target = MatMul; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Map2Alpha for MatMulWithAlpha { + const OP: &'static str = "mat_mul_a"; + + #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + use gemm::{gemm, Parallelism}; + + match T::DTYPE { + DType::F16 | DType::F32 | DType::F64 => {} + _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?, + } + + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: 
usize = m * n; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let mut dst = vec![T::zero(); b * m * n]; + let num_threads = crate::utils::get_num_threads(); + let parallelism = if num_threads > 1 { + Parallelism::Rayon(num_threads) + } else { + Parallelism::None + }; + + let alpha = T::from_f64(s.unwrap_or(1.0)); + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ true, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::one(), + /* beta: T = */ alpha, + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + parallelism, + ) + } + } + Ok(dst) + } + + #[cfg(feature = "accelerate")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + crate::bail!("the accelerate backend does not support f16 matmul") + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::accelerate::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 1., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } + + #[cfg(feature = "mkl")] + fn f( + &self, + lhs: &[T], + lhs_l: &Layout, + rhs: &[T], + rhs_l: &Layout, + s: Option, + ) -> Result> { + let (b, m, n, k) = self.0 .0; + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + + let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?; + let c_skip: usize = m * n; + + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, b'N') + } else if rhs_m1 == k && rhs_m2 == 1 { + (k as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? + }; + // The b tensor has dims batching, m, k (lhs) + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, b'N') + } else if lhs_m1 == m && lhs_m2 == 1 { + (m as i32, b'T') + } else { + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? + }; + + let mut dst = vec![T::zero(); b * m * n]; + match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::from_f64(s.unwrap_or(1.)), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ONE, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::sgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.) 
as f32, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::dgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ s.unwrap_or(1.), + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ 0., + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?, + } + Ok(dst) + } +} + +fn elu(v: T, alpha: T) -> T { + if v.is_sign_positive() { + v + } else { + (v.exp() - T::one()) * alpha + } +} + +impl CpuStorage { + pub fn as_slice(&self) -> Result<&[D]> { + D::cpu_storage_as_slice(self) + } + + pub fn concat(storages: &[CpuStorage]) -> Result { + let storage0 = &storages[0]; + let s = match storage0 { + Self::U8(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::U8(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::U8(storages) + } + Self::U32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::U32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::U32(storages) + } + Self::I16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I16(storages) + } + Self::I32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I32(storages) + } + Self::I64(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I64(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I64(storages) + } + Self::BF16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::BF16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::BF16(storages) + } + Self::F16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F16(storages) + } + Self::F32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F32(storages) + } + Self::F64(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F64(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F64(storages) + } + Self::F8E4M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E4M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? 
+ .concat(); + Self::F8E4M3(storages) + } + }; + Ok(s) + } +} + +impl BackendStorage for CpuStorage { + type Device = CpuDevice; + + fn dtype(&self) -> DType { + match self { + Self::U8(_) => DType::U8, + Self::U32(_) => DType::U32, + Self::I16(_) => DType::I16, + Self::I32(_) => DType::I32, + Self::I64(_) => DType::I64, + Self::BF16(_) => DType::BF16, + Self::F16(_) => DType::F16, + Self::F32(_) => DType::F32, + Self::F64(_) => DType::F64, + Self::F8E4M3(_) => DType::F8E4M3, + } + } + + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + // TODO: find a way around the quadratic number of cases below. + match (self, dtype) { + (Self::U8(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::U32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::I64(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } + (Self::BF16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::BF16(data)) + } + (Self::F16(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::F32(storage), DType::BF16) => { + let data = unary_map(storage, layout, bf16::from_f32); + Ok(Self::BF16(data)) + } + (Self::F64(storage), DType::BF16) => { + let data = unary_map(storage, layout, bf16::from_f64); + Ok(Self::BF16(data)) + } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } + (Self::U8(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::U32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::I64(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } + (Self::BF16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } + (Self::F16(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F16(data)) + } + (Self::F32(storage), DType::F16) => { + let data = unary_map(storage, layout, f16::from_f32); + Ok(Self::F16(data)) + } + (Self::F64(storage), DType::F16) => { + let data = unary_map(storage, layout, f16::from_f64); + Ok(Self::F16(data)) + } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } + (Self::U8(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::U32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + 
(Self::I16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::I32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::I64(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::BF16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::F16(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::F32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F32(data)) + } + (Self::F64(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } + (Self::U8(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::U8(data)) + } + (Self::BF16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::F64(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::U32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I16(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::I64(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::U8(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::U32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::U32(data)) + } + (Self::I16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::I64(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::BF16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::F16(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::F32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::F64(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } + (Self::U8(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), 
DType::I16) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I16(data)) + } + (Self::I32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::I64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::BF16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F16(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::F32(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F64(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v as i16); + Ok(Self::I16(data)) + } + (Self::F8E4M3(storage), DType::I16) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i16); + Ok(Self::I16(data)) + } + (Self::U8(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::I32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I32(data)) + } + (Self::I64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::BF16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F8E4M3(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::U8(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I64(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I64(data)) + } + (Self::BF16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } + (Self::F16(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } + (Self::F32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::F64(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + 
} + (Self::U8(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::U32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::I64(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::BF16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::F16(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::F32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } + (Self::F64(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F64(data)) + } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::U8(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::U32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::BF16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f32); + Ok(Self::F8E4M3(data)) + } + (Self::F64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f64); + Ok(Self::F8E4M3(data)) + } + (Self::F8E4M3(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F8E4M3(data)) + } + } + } + + fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result { + match op { + ReduceOp::Sum => { + let src_dims = layout.dims(); + let mut dst_dims = src_dims.to_vec(); + for &dim in reduce_dims.iter() { + dst_dims[dim] = 1; + } + let dst_shape = Shape::from(dst_dims); + let mut reduce_dims = reduce_dims.to_vec(); + // Sort the reduce_dims as they have to be processed from left to right when converting the + // indexes. 
+ reduce_dims.sort(); + let reduce_dims_and_stride: Vec<_> = reduce_dims + .iter() + .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::())) + .collect(); + ReduceSum { + dst_shape: &dst_shape, + reduce_dims: &reduce_dims, + reduce_dims_and_stride, + } + .map(self, layout) + } + ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => { + let reduce_dim_index = match reduce_dims { + [reduce_dim_index] => *reduce_dim_index, + _ => { + let op = match op { + ReduceOp::Min => "min", + ReduceOp::ArgMin => "argmin", + ReduceOp::Max => "max", + ReduceOp::ArgMax => "argmax", + _ => unreachable!(), + }; + let dims = reduce_dims.to_vec(); + Err(Error::OnlySingleDimension { op, dims })? + } + }; + let (use_min, return_index) = match op { + ReduceOp::Min => (true, false), + ReduceOp::ArgMin => (true, true), + ReduceOp::Max => (false, false), + ReduceOp::ArgMax => (false, true), + _ => unreachable!(), + }; + ReduceIndex { + reduce_dim_index, + use_min, + return_index, + } + .map(self, layout) + } + } + } + + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { + Cmp(op).map(self, lhs_l, rhs, rhs_l) + } + + fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + Affine(mul, add).map(self, layout) + } + + fn avg_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result { + AvgPool2D(kernel_size, stride).map(self, layout) + } + + fn max_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result { + MaxPool2D(kernel_size, stride).map(self, layout) + } + + fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result { + UpsampleNearest1D(sz).map(self, layout) + } + + fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result { + UpsampleNearest2D(h, w).map(self, layout) + } + + fn powf(&self, layout: &Layout, e: f64) -> Result { + use num_traits::Float; + // TODO: Have some generic map for functions that apply on num_traits::Float elements. + match self { + Self::BF16(storage) => { + let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e))); + Ok(Self::BF16(data)) + } + Self::F16(storage) => { + let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e))); + Ok(Self::F16(data)) + } + Self::F32(storage) => { + let data = unary_map(storage, layout, |v| v.powf(e as f32)); + Ok(Self::F32(data)) + } + Self::F64(storage) => { + let data = unary_map(storage, layout, |v| v.powf(e)); + Ok(Self::F64(data)) + } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); + Ok(Self::F8E4M3(data)) + } + Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), + Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), + Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + } + } + + fn elu(&self, layout: &Layout, alpha: f64) -> Result { + // TODO: Have some generic map for functions that apply on num_traits::Float elements. 
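+        // Elementwise ELU: v when v >= 0, alpha * (exp(v) - 1) otherwise (see the free `elu`
+        // helper above). Only float dtypes are supported; integer inputs report an error below.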
+ match self { + Self::BF16(storage) => { + let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha))); + Ok(Self::BF16(data)) + } + Self::F16(storage) => { + let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha))); + Ok(Self::F16(data)) + } + Self::F32(storage) => { + let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha))); + Ok(Self::F32(data)) + } + Self::F64(storage) => { + let data = unary_map(storage, layout, |v| elu(v, alpha)); + Ok(Self::F64(data)) + } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); + Ok(Self::F8E4M3(data)) + } + Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), + Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), + Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), + } + } + + fn unary_impl(&self, layout: &Layout) -> Result { + match self { + Self::BF16(storage) => { + if B::BF16_VEC { + let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec); + Ok(Self::BF16(data)) + } else { + let data = unary_map(storage, layout, B::bf16); + Ok(Self::BF16(data)) + } + } + Self::F16(storage) => { + if B::F16_VEC { + let data = unary_map_vec(storage, layout, B::f16, B::f16_vec); + Ok(Self::F16(data)) + } else { + let data = unary_map(storage, layout, B::f16); + Ok(Self::F16(data)) + } + } + Self::F32(storage) => { + if B::F32_VEC { + let data = unary_map_vec(storage, layout, B::f32, B::f32_vec); + Ok(Self::F32(data)) + } else { + let data = unary_map(storage, layout, B::f32); + Ok(Self::F32(data)) + } + } + Self::F64(storage) => { + if B::F64_VEC { + let data = unary_map_vec(storage, layout, B::f64, B::f64_vec); + Ok(Self::F64(data)) + } else { + let data = unary_map(storage, layout, B::f64); + Ok(Self::F64(data)) + } + } + Self::F8E4M3(storage) => { + if B::F8E4M3_VEC { + let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec); + Ok(Self::F8E4M3(data)) + } else { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + } + Self::U8(storage) => { + let data = unary_map(storage, layout, B::u8); + Ok(Self::U8(data)) + } + Self::U32(storage) => { + let data = unary_map(storage, layout, B::u32); + Ok(Self::U32(data)) + } + Self::I16(storage) => { + let data = unary_map(storage, layout, B::i16); + Ok(Self::I16(data)) + } + Self::I32(storage) => { + let data = unary_map(storage, layout, B::i32); + Ok(Self::I32(data)) + } + Self::I64(storage) => { + let data = unary_map(storage, layout, B::i64); + Ok(Self::I64(data)) + } + } + } + + fn binary_impl( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + match (self, rhs) { + (Self::BF16(lhs), Self::BF16(rhs)) => { + let data = if B::BF16_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16) + }; + Ok(Self::BF16(data)) + } + (Self::F16(lhs), Self::F16(rhs)) => { + let data = if B::F16_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::f16) + }; + Ok(Self::F16(data)) + } + (Self::F32(lhs), Self::F32(rhs)) => { + let data = if B::F32_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::f32) + }; + Ok(Self::F32(data)) + } + (Self::F64(lhs), 
Self::F64(rhs)) => { + let data = if B::F64_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::f64) + }; + Ok(Self::F64(data)) + } + (Self::U32(lhs), Self::U32(rhs)) => { + let data = if B::U32_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::u32) + }; + Ok(Self::U32(data)) + } + (Self::I16(lhs), Self::I16(rhs)) => { + let data = if B::I16_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i16, B::i16_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i16) + }; + Ok(Self::I16(data)) + } + (Self::I32(lhs), Self::I32(rhs)) => { + let data = if B::I32_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i32) + }; + Ok(Self::I32(data)) + } + (Self::I64(lhs), Self::I64(rhs)) => { + let data = if B::I64_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i64) + }; + Ok(Self::I64(data)) + } + (Self::U8(lhs), Self::U8(rhs)) => { + let data = if B::U8_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::u8) + }; + Ok(Self::U8(data)) + } + _ => { + // This should be covered by the dtype check above. + Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: rhs.dtype(), + op: B::NAME, + } + .bt()) + } + } + } + + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + match (self, dst) { + (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::U32(src), Self::U32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I16(src), Self::I16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I32(src), Self::I32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I64(src), Self::I64(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::BF16(src), Self::BF16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F16(src), Self::F16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F32(src), Self::F32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F64(src), Self::F64(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (_, dst) => { + return Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: dst.dtype(), + op: "copy2d", + } + .bt()); + } + } + Ok(()) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + match (self, dst) { + (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, 
dst_offset, src_l), + (_, dst) => { + // This should be covered by the dtype check above. + return Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: dst.dtype(), + op: "copy_strided", + } + .bt()); + } + } + Ok(()) + } + + fn where_cond( + &self, + layout: &Layout, + t: &Self, + t_l: &Layout, + f: &Self, + f_l: &Layout, + ) -> Result { + match self { + Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), + } + } + + fn conv1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv1D, + ) -> Result { + if !USE_IM2COL_CONV1D { + return Conv1D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col1D { + l_k: params.k_size, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let l_out = params.l_out(); + let k = op.l_k * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) + } + + fn conv_transpose1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConvTranspose1D, + ) -> Result { + let can_use_col2im = kernel_l.is_contiguous() + && params.dilation == 1 + && params.padding == 0 + && params.output_padding == 0; + if USE_COL2IM_CONV1D_TR && can_use_col2im { + let (b_size, c_in, l_in) = l.shape().dims3()?; + let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?; + if !kernel_l.is_contiguous() { + crate::bail!( + "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}" + ) + } + if c_in != c_in2 { + crate::bail!( + "convtr1d: shape mismatch on c_in {:?} {:?}", + l.shape(), + kernel_l.shape() + ) + } + let col = { + // This merges the last two dimensions of the kernel together. + let kernel_l_mm = Layout::new( + (b_size, c_in, k_size * c_out).into(), + vec![0, k_size * c_out, 1], + kernel_l.start_offset(), + ); + self.matmul_with_alpha( + kernel, + None, + ( + b_size, + /* m */ l_in, + /* n */ c_out * k_size, + /* k */ c_in, + ), + &l.transpose(1, 2)?, + &kernel_l_mm, + )? 
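+                // `col` has shape (b_size, l_in, c_out * k_size): for each input position, the
+                // contribution of every (output channel, kernel tap) pair. Col2Im1D below
+                // scatters (accumulates) these into the output according to the stride.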
+ }; + let col_l = Layout::contiguous((b_size, l_in, c_out, k_size)); + Col2Im1D { + stride: params.stride, + } + .map(&col, &col_l) + } else { + ConvTranspose1D(params).map(self, l, kernel, kernel_l) + } + } + + fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result { + if !USE_IM2COL_CONV2D { + return Conv2D(params).map(self, l, kernel, kernel_l); + } + let op = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + padding: params.padding, + stride: params.stride, + dilation: params.dilation, + }; + let col = op.map(self, l)?; + let b = params.b_size; + let n = params.c_out; + let (h_out, w_out) = (params.out_h(), params.out_w()); + let k = op.h_k * op.w_k * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, params.c_out)) + .transpose(1, 2)? + .transpose(1, 3)?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) + } + + fn conv_transpose2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + ConvTranspose2D(params).map(self, l, kernel, kernel_l) + } + + fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result { + match ids { + Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), + } + } + + fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { + match ids { + Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I16(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), + } + } + + fn scatter_add( + &self, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, + ) -> Result { + match ids { + Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I16(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + _ => 
Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), + } + } + + fn index_add( + &self, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, + ) -> Result { + match ids { + Self::U8(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::U32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::I16(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::I32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + Self::I64(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()), + } + } + + fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + MatMulWithBias(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, c, c_l, s) + } + + fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + MatMulWithAlpha(MatMul(bmnk)).map(self, lhs_l, rhs, rhs_l, s) + } + + fn device(&self) -> &Self::Device { + &CpuDevice + } + + fn try_clone(&self, _: &Layout) -> Result { + Ok(self.clone()) + } + + fn to_cpu_storage(&self) -> Result { + Ok(self.clone()) + } +} + +impl BackendDevice for CpuDevice { + type Storage = CpuStorage; + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Cpu + } + + fn same_device(&self, _: &Self) -> bool { + true + } + + fn storage_from_slice(&self, s: &[T]) -> Result { + Ok(T::to_cpu_storage(s)) + } + + fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result { + Ok(s.clone()) + } + + fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result { + Ok(s) + } + + fn new(_: usize) -> Result { + Ok(Self) + } + + fn set_seed(&self, _seed: u64) -> Result<()> { + crate::bail!("cannot seed the CPU rng with set_seed") + } + + fn get_current_seed(&self) -> Result { + crate::bail!("cannot get the CPU rng seed with get_current_seed") + } + + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result { + use rand::prelude::*; + + let elem_count = shape.elem_count(); + let mut rng = rand::thread_rng(); + match dtype { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { + Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) + } + DType::BF16 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::BF16(data)) + } + DType::F16 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + 
rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F16(data)) + } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F8E4M3(data)) + } + DType::F32 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = rand::distributions::Uniform::new(min as f32, max as f32); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F32(data)) + } + DType::F64 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = rand::distributions::Uniform::new(min, max); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F64(data)) + } + } + } + + fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result { + use rand::prelude::*; + + let elem_count = shape.elem_count(); + let mut rng = rand::thread_rng(); + match dtype { + DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => { + Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) + } + DType::BF16 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::BF16(data)) + } + DType::F16 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F16(data)) + } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F8E4M3(data)) + } + DType::F32 => { + let mut data = Vec::with_capacity(elem_count); + let normal = + rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F32(data)) + } + DType::F64 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F64(data)) + } + } + } + + #[allow(clippy::uninit_vec)] + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + // The code below is highly unsafe but hopefully not directly unsound as we only consider + // types that are Copy, not Drop, and for which all bit patterns are proper values. 
+ // It's still pretty risky, see the following for more details: + // https://github.com/rust-lang/rust-clippy/issues/4483 + let storage = match dtype { + DType::U8 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::U8(v) + } + DType::U32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::U32(v) + } + DType::I16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I16(v) + } + DType::I32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I32(v) + } + DType::I64 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I64(v) + } + DType::BF16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::BF16(v) + } + DType::F16 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F16(v) + } + DType::F32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F32(v) + } + DType::F64 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F64(v) + } + DType::F8E4M3 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F8E4M3(v) + } + }; + Ok(storage) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let storage = match dtype { + DType::U8 => CpuStorage::U8(vec![1u8; elem_count]), + DType::U32 => CpuStorage::U32(vec![1u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![1i16; elem_count]), + DType::I32 => CpuStorage::I32(vec![1i32; elem_count]), + DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), + DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), + DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]), + DType::F32 => CpuStorage::F32(vec![1f32; elem_count]), + DType::F64 => CpuStorage::F64(vec![1f64; elem_count]), + }; + Ok(storage) + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let storage = match dtype { + DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), + DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I16 => CpuStorage::I16(vec![0i16; elem_count]), + DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), + DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), + DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), + DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), + DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), + DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), + }; + Ok(storage) + } + + fn synchronize(&self) -> Result<()> { + Ok(()) + } +} + +#[macro_export] +macro_rules! map_dtype { + ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => { + match $storage { + $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)* + s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?, + } + }; +} diff --git a/mcandle/mcandle-core/src/cpu_backend/utils.rs b/mcandle/mcandle-core/src/cpu_backend/utils.rs new file mode 100644 index 000000000..9e7538581 --- /dev/null +++ b/mcandle/mcandle-core/src/cpu_backend/utils.rs @@ -0,0 +1,451 @@ +/// Helper functions to write CPU kernels. 
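The helper traits defined in this file (Map1, Map1Any, Map2, Map3, Map2Alpha, Map2U8) share one shape: an op implements a single `f` that is generic over `T: WithDType`, and the provided `map` performs the per-dtype dispatch across the CpuStorage variants, with the binary variants reporting a dtype-mismatch error when the operands disagree. Below is a minimal sketch of an op written against this pattern; `Scale` is a made-up example, and the sketch assumes `f` is declared as `fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>` and that `WithDType` provides a `from_f64` conversion (as in upstream candle).

// Illustrative only: a scalar-multiply op plugged into the Map1 pattern defined in this file.
// `Scale` is not an op introduced by this diff.
struct Scale(f64);

impl Map1 for Scale {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        // unary_map (defined later in this file) walks the layout, either as a single
        // contiguous block or as strided blocks, and applies the closure to every element.
        Ok(unary_map(vs, layout, |v| v * T::from_f64(self.0)))
    }
}

// Usage: `Scale(0.5).map(&storage, &layout)` returns a CpuStorage of the same dtype.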
+use crate::backend::BackendStorage; +use crate::{Error, Layout, Result, WithDType}; + +type C = super::CpuStorage; +pub trait Map1 { + fn f(&self, vs: &[T], layout: &Layout) -> Result>; + + fn map(&self, vs: &C, layout: &Layout) -> Result { + match vs { + C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), + C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)), + C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), + C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), + C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), + C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), + C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), + C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), + } + } +} + +pub trait Map1Any { + fn f) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result; + + fn map(&self, vs: &C, layout: &Layout) -> Result { + match vs { + C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), + C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I16(vs) => Ok(self.f(vs, layout, C::I16)?), + C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), + C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), + C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), + C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), + C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), + C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), + } + } +} + +pub trait Map2 { + const OP: &'static str; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; + + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map3 { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + v3: &mut [T], + l3: &Layout, + s: Option, + ) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + v1: &C, + l1: &Layout, + v2: &C, + l2: &Layout, + v3: &mut C, + l3: &Layout, + s: Option, + ) -> Result<()> { + let v3d = v3.dtype(); + match (v1, v2, v3) { + (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::I32(v1), C::I32(v2), C::I32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), 
+ _ => Err(Error::DTypeMismatchBinaryOp3 { + lhs: v1.dtype(), + rhs: v2.dtype(), + c: v3d, + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2Alpha { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + s: Option, + ) -> Result>; + + #[allow(clippy::too_many_arguments)] + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2, s)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2, s)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2U8 { + const OP: &'static str; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; + + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I32(v1), C::I32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub fn binary_map U>( + lhs_l: &Layout, + rhs_l: &Layout, + lhs: &[T], + rhs: &[T], + mut f: F, +) -> Vec { + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] + .iter() + .zip(rhs[o_r1..o_r2].iter()) + .map(|(&l, &r)| f(l, r)) + .collect(), + (Some((o_l1, o_l2)), None) => { + // TODO: Maybe we want to avoid going through the layout twice. + match rhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + lhs[o_l1..o_l2] + .iter() + .map(|&l| { + let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(l, *r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + (None, Some((o_r1, o_r2))) => { + // TODO: Maybe we want to avoid going through the layout twice. 
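+            // `offsets_b` describes a broadcast layout as one contiguous block of `len` source
+            // elements starting at `start`, where each element is repeated `right_broadcast`
+            // times and the whole block is repeated `left_broadcast` times; the two counters
+            // below walk that pattern in lockstep with the contiguous rhs.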
+ match lhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + rhs[o_r1..o_r2] + .iter() + .map(|&r| { + let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(*l, r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } +} + +// Similar to binary_map but with vectorized variants. +pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])>( + lhs_l: &Layout, + rhs_l: &Layout, + lhs: &[T], + rhs: &[T], + mut f: F, + mut f_vec: FV, +) -> Vec { + let el_count = lhs_l.shape().elem_count(); + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; + f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { + Some(ob) if ob.right_broadcast == 1 => { + let rhs = &rhs[ob.start..ob.start + ob.len]; + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; + let mut dst_i = 0; + for src_i in (o_l1..o_l2).step_by(ob.len) { + f_vec( + &lhs[src_i..src_i + ob.len], + rhs, + &mut ys_to_set[dst_i..dst_i + ob.len], + ); + dst_i += ob.len; + } + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + Some(ob) => { + let rhs = &rhs[ob.start..ob.start + ob.len]; + let mut ys = lhs[o_l1..o_l2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &r) in rhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(*v, r) + } + } + } + ys + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + }, + (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { + Some(ob) if ob.right_broadcast == 1 => { + let lhs = &lhs[ob.start..ob.start + ob.len]; + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; + let mut dst_i = 0; + for src_i in (o_r1..o_r2).step_by(ob.len) { + f_vec( + lhs, + &rhs[src_i..src_i + ob.len], + &mut ys_to_set[dst_i..dst_i + ob.len], + ); + dst_i += ob.len; + } + // SAFETY: values are all set by f_vec. 
+ unsafe { ys.set_len(el_count) }; + ys + } + Some(ob) => { + let lhs = &lhs[ob.start..ob.start + ob.len]; + let mut ys = rhs[o_r1..o_r2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &l) in lhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(l, *v) + } + } + } + ys + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + }, + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } +} + +pub fn unary_map U>( + vs: &[T], + layout: &Layout, + mut f: F, +) -> Vec { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => vs + [start_offset..start_offset + len] + .iter() + .map(|&v| f(v)) + .collect(), + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let mut result = Vec::with_capacity(layout.shape().elem_count()); + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + } else { + for index in block_start_index { + for offset in 0..block_len { + let v = unsafe { vs.get_unchecked(index + offset) }; + result.push(f(*v)) + } + } + } + result + } + } +} + +pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( + vs: &[T], + layout: &Layout, + mut f: F, + mut f_vec: FV, +) -> Vec { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + let mut ys: Vec = Vec::with_capacity(len); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) + }; + f_vec(&vs[start_offset..start_offset + len], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(len) }; + ys + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let el_count = layout.shape().elem_count(); + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + let mut result = Vec::with_capacity(el_count); + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + result + } else { + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) + }; + let mut dst_index = 0; + for src_index in block_start_index { + let vs = &vs[src_index..src_index + block_len]; + let ys = &mut ys_to_set[dst_index..dst_index + block_len]; + f_vec(vs, ys); + dst_index += block_len; + } + // SAFETY: values are all set by f_vec. 
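// A minimal standalone sketch of the MultipleBlocks path above: every block is a
// contiguous run of `block_len` source elements beginning at one of the block start
// indices (names are illustrative, std-only):
fn map_blocks<T: Copy, U>(
    vs: &[T],
    block_start_index: impl IntoIterator<Item = usize>,
    block_len: usize,
    mut f: impl FnMut(T) -> U,
) -> Vec<U> {
    let mut result = Vec::new();
    for index in block_start_index {
        for &v in &vs[index..index + block_len] {
            result.push(f(v));
        }
    }
    result
}
// With vs = [10, 11, 12, 13, 14, 15], block_start_index = [4, 0] and block_len = 2
// this visits 14, 15, 10, 11, i.e. a permuted, non-contiguous layout is walked without
// materializing an index for every element.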
+ unsafe { ys.set_len(el_count) }; + ys + } + } + } +} diff --git a/mcandle/mcandle-core/src/cuda_backend/cudnn.rs b/mcandle/mcandle-core/src/cuda_backend/cudnn.rs new file mode 100644 index 000000000..f5b4db902 --- /dev/null +++ b/mcandle/mcandle-core/src/cuda_backend/cudnn.rs @@ -0,0 +1,124 @@ +use crate::WithDType; +use cudarc; +use cudarc::cudnn::safe::{ConvForward, Cudnn}; +use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::sync::Arc; + +// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither +// send nor sync. +thread_local! { + static CUDNN: RefCell>> = HashMap::new().into(); +} + +impl From for crate::Error { + fn from(err: cudarc::cudnn::CudnnError) -> Self { + crate::Error::wrap(err) + } +} + +impl From for crate::Error { + fn from(err: cudarc::driver::DriverError) -> Self { + crate::Error::wrap(err) + } +} + +pub(crate) fn launch_conv2d< + T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType, + Y: cudarc::cudnn::CudnnDataType, +>( + src: &CudaView, + src_l: &crate::Layout, + filter: &CudaView, + dst: &mut CudaSlice, + params: &crate::conv::ParamsConv2D, + dev: &crate::cuda_backend::CudaDevice, +) -> crate::Result<()> { + use crate::conv::CudnnFwdAlgo as CandleAlgo; + use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A; + + let device_id = dev.id(); + let cudnn = CUDNN.with(|cudnn| { + if let Some(cudnn) = cudnn.borrow().get(&device_id) { + return Ok(cudnn.clone()); + } + let c = Cudnn::new(dev.cuda_device()); + if let Ok(c) = &c { + cudnn.borrow_mut().insert(device_id, c.clone()); + } + c + })?; + let conv = cudnn.create_conv2d::( + /* pad */ [params.padding as i32, params.padding as i32], + /* stride */ [params.stride as i32, params.stride as i32], + /* dilation */ [params.dilation as i32, params.dilation as i32], + cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION, + )?; + let x_shape = [ + params.b_size as i32, + params.c_in as i32, + params.i_h as i32, + params.i_w as i32, + ]; + // Note that `src` already starts at the proper offset. + let x = if src_l.is_contiguous() { + cudnn.create_4d_tensor::( + cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, + x_shape, + )? + } else { + let s = src_l.stride(); + cudnn.create_4d_tensor_ex::( + x_shape, + [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32], + )? 
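// A minimal standalone sketch of the per-thread handle cache declared at the top of
// this file, assuming a handle type that is neither Send nor Sync (names are
// illustrative, std-only): each thread lazily creates and memoizes one handle per
// device id, mirroring the thread_local CUDNN map above.
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct DevId(usize);
struct Handle; // stand-in for the real cudnn handle

thread_local! {
    static HANDLES: RefCell<HashMap<DevId, Rc<Handle>>> = HashMap::new().into();
}

fn handle_for(dev: DevId) -> Rc<Handle> {
    HANDLES.with(|h| {
        if let Some(handle) = h.borrow().get(&dev) {
            return handle.clone();
        }
        let handle = Rc::new(Handle);
        h.borrow_mut().insert(dev, handle.clone());
        handle
    })
}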
+ }; + let w = cudnn.create_4d_filter::( + cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, + [ + params.c_out as i32, + params.c_in as i32, + params.k_h as i32, + params.k_w as i32, + ], + )?; + let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); + let y = cudnn.create_4d_tensor::( + cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, + [params.b_size as i32, params.c_out as i32, h_out, w_out], + )?; + let conv2d = ConvForward { + conv: &conv, + x: &x, + w: &w, + y: &y, + }; + let alg = match params.cudnn_fwd_algo { + None => conv2d.pick_algorithm()?, + Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + Some(CandleAlgo::ImplicitPrecompGemm) => { + A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + } + Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT, + Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT, + }; + let workspace_size = conv2d.get_workspace_size(alg)?; + let mut workspace = dev.cuda_device().alloc_zeros::(workspace_size)?; + unsafe { + conv2d.launch::, _, _, _>( + alg, + Some(&mut workspace), + (T::one(), T::zero()), + src, + filter, + dst, + )?; + } + Ok(()) +} diff --git a/mcandle/mcandle-core/src/cuda_backend/device.rs b/mcandle/mcandle-core/src/cuda_backend/device.rs new file mode 100644 index 000000000..c73e5077b --- /dev/null +++ b/mcandle/mcandle-core/src/cuda_backend/device.rs @@ -0,0 +1,594 @@ +use crate::backend::BackendDevice; +use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; +pub use cudarc; +use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use float8::F8E4M3; +use half::{bf16, f16}; +pub use mcandle_kernels as kernels; +use std::sync::{Arc, Mutex, RwLock}; + +use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; + +/// Unique identifier for cuda devices. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +struct CudaRng(cudarc::curand::CudaRng); +unsafe impl Send for CudaRng {} + +#[derive(Clone)] +pub struct CudaDevice { + id: DeviceId, + device: Arc, + pub(crate) blas: Arc, + curand: Arc>, + seed_value: Arc>, +} + +impl std::fmt::Debug for CudaDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CudaDevice({:?})", self.id) + } +} + +impl std::ops::Deref for CudaDevice { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl CudaDevice { + pub fn cublas_handle(&self) -> &cudarc::cublas::CudaBlas { + &*self.blas + } + + pub fn cuda_device(&self) -> Arc { + self.device.clone() + } + + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?; + let cuda_code = String::from_utf8(buf)?; + let opts = cudarc::nvrtc::CompileOptions { + use_fast_math: Some(true), + ..Default::default() + }; + let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?; + self.device.load_ptx(ptx, "ug", &[func_name]).w()?; + let func = match self.device.get_func("ug", func_name) { + Some(func) => func, + None => crate::bail!("unknown function ug::{func_name}"), + }; + Ok(func) + } + + pub fn id(&self) -> DeviceId { + self.id + } + + fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let slice = match dtype { + DType::U8 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_u8", kernels::FILL)?; + let params = (&data, v as u8, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_u32", kernels::FILL)?; + let params = (&data, v as u32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U32(data) + } + DType::I16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i16", kernels::FILL)?; + let params = (&data, v as i16, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i32", kernels::FILL)?; + let params = (&data, v as i32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(data) + } + DType::I64 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i64", kernels::FILL)?; + let params = (&data, v as i64, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + // SAFETY: Set later by running the fill kernel. 
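// A minimal standalone sketch of the id scheme used by `DeviceId::new` above
// (illustrative, std-only): a process-wide atomic counter hands out ids starting at 1,
// so two logical device handles never compare equal by accident even when they wrap
// the same CUDA ordinal.
use std::sync::atomic::{AtomicUsize, Ordering};

fn next_unique_id() -> usize {
    static COUNTER: AtomicUsize = AtomicUsize::new(1);
    COUNTER.fetch_add(1, Ordering::Relaxed)
}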
+ let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; + let params = (&data, bf16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f16", kernels::FILL)?; + let params = (&data, f16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f32", kernels::FILL)?; + let params = (&data, v as f32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f64", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(data) + } + DType::F8E4M3 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f8_e4m3", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result { + if !self.has_func(module_name, module_name) { + // Leaking the string here is a bit sad but we need a &'static str and this is only + // done once per kernel name. + let static_module_name = Box::leak(module_name.to_string().into_boxed_str()); + self.load_ptx(ptx.into(), module_name, &[static_module_name]) + .map_err(|cuda| CudaError::Load { + cuda, + module_name: module_name.to_string(), + }) + .w()?; + } + self.get_func(module_name, module_name) + // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is + // able to only build the error value if needed. + .ok_or(CudaError::MissingKernel { + module_name: module_name.to_string(), + }) + .w() + } +} + +impl CudaDevice { + pub fn new_with_stream(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?; + let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; + let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; + Ok(Self { + id: DeviceId::new(), + device, + blas: Arc::new(blas), + curand: Arc::new(Mutex::new(CudaRng(curand))), + seed_value: Arc::new(RwLock::new(299792458)), + }) + } +} + +impl BackendDevice for CudaDevice { + type Storage = CudaStorage; + + fn new(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new(ordinal).w()?; + let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; + let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; + Ok(Self { + id: DeviceId::new(), + device, + blas: Arc::new(blas), + curand: Arc::new(Mutex::new(CudaRng(curand))), + seed_value: Arc::new(RwLock::new(299792458)), + }) + } + + fn set_seed(&self, seed: u64) -> Result<()> { + // We do not call set_seed but instead create a new curand object. This ensures that the + // state will be identical and the same random numbers will be generated. 
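// A minimal sketch of the `Box::leak` trick used by `get_or_load_func` above
// (illustrative, std-only): the PTX loader wants a `&'static str`, and leaking one
// small string per distinct kernel name is acceptable because each name is only
// leaked the first time that kernel is loaded.
fn leak_name(module_name: &str) -> &'static str {
    Box::leak(module_name.to_string().into_boxed_str())
}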
+ let mut curand = self.curand.lock().unwrap(); + curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; + *self.seed_value.write().unwrap() = seed; + Ok(()) + } + + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Cuda { + gpu_id: self.device.ordinal(), + } + } + + fn same_device(&self, rhs: &Self) -> bool { + self.id == rhs.id + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I32(data) + } + DType::I64 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F64(data) + } + DType::F8E4M3 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result { + let elem_count = shape.elem_count(); + let curand = self.curand.lock().unwrap(); + let slice = match dtype { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. + DType::U8 + | DType::U32 + | DType::I64 + | DType::I32 + | DType::I16 + | DType::F16 + | DType::BF16 + | DType::F8E4M3 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()?, + DType::F32 => { + let mut data = unsafe { self.alloc::(elem_count) }.w()?; + curand.0.fill_with_uniform(&mut data).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let mut data = unsafe { self.alloc::(elem_count) }.w()?; + curand.0.fill_with_uniform(&mut data).w()?; + CudaStorageSlice::F64(data) + } + }; + let slice = if lo == 0. && up == 1.0 { + slice + } else { + use super::utils::Map1; + let layout = Layout::contiguous(shape); + super::Affine(up - lo, lo).map(&slice, self, &layout)? + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. + let elem_count = shape.elem_count(); + let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. 
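// Two scalar helpers mirroring the logic used by the random ops here, assuming `u` is
// a sample drawn uniformly in [0, 1) (names are illustrative, std-only): rand_uniform
// rescales its samples via Affine(up - lo, lo), and rand_normal rounds odd element
// counts up by one so the generator is always asked for an even number of values.
fn rescale_uniform(u: f64, lo: f64, up: f64) -> f64 {
    (up - lo) * u + lo
}
fn round_count_up_to_even(elem_count: usize) -> usize {
    if elem_count % 2 == 1 {
        elem_count + 1
    } else {
        elem_count
    }
}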
+ // https://github.com/huggingface/mcandle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; + let slice = match dtype { + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 + | DType::F8E4M3 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, + DType::F32 => { + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; + curand + .0 + .fill_with_normal(&mut data, mean as f32, std as f32) + .w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; + curand.0.fill_with_normal(&mut data, mean, std).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + self.const_impl(1., shape, dtype) + } + + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I16(data) + } + DType::I32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I32(data) + } + DType::I64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F64(data) + } + DType::F8E4M3 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_slice(&self, s: &[T]) -> Result { + let slice = match T::cpu_storage_ref(s) { + CpuStorageRef::U8(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorageRef::U32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorageRef::I16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I16(data) + } + CpuStorageRef::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } + CpuStorageRef::I64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorageRef::BF16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorageRef::F16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorageRef::F32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorageRef::F64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + CpuStorageRef::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { + let slice = match storage { + 
CpuStorage::U8(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I16(data) + } + CpuStorage::I32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn synchronize(&self) -> Result<()> { + self.device.synchronize().map_err(crate::Error::wrap)?; + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/cuda_backend/error.rs b/mcandle/mcandle-core/src/cuda_backend/error.rs new file mode 100644 index 000000000..bd6f8ac6d --- /dev/null +++ b/mcandle/mcandle-core/src/cuda_backend/error.rs @@ -0,0 +1,62 @@ +use crate::{DType, Layout}; + +/// cudarc related errors +#[derive(thiserror::Error, Debug)] +pub enum CudaError { + #[error(transparent)] + Cuda(#[from] cudarc::driver::DriverError), + + #[error(transparent)] + Compiler(#[from] cudarc::nvrtc::CompileError), + + #[error(transparent)] + Cublas(#[from] cudarc::cublas::result::CublasError), + + #[error(transparent)] + Curand(#[from] cudarc::curand::result::CurandError), + + #[error("missing kernel '{module_name}'")] + MissingKernel { module_name: String }, + + #[error("unsupported dtype {dtype:?} for {op}")] + UnsupportedDtype { dtype: DType, op: &'static str }, + + #[error("internal error '{0}'")] + InternalError(&'static str), + + #[error("matmul is only 
supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Layout, + rhs_stride: Layout, + mnk: (usize, usize, usize), + }, + + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, + + #[error("{cuda} when loading {module_name}")] + Load { + cuda: cudarc::driver::DriverError, + module_name: String, + }, +} + +impl From for crate::Error { + fn from(val: CudaError) -> Self { + crate::Error::Cuda(Box::new(val)).bt() + } +} + +pub trait WrapErr { + fn w(self) -> std::result::Result; +} + +impl> WrapErr for std::result::Result { + fn w(self) -> std::result::Result { + self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt()) + } +} diff --git a/mcandle/mcandle-core/src/cuda_backend/mod.rs b/mcandle/mcandle-core/src/cuda_backend/mod.rs new file mode 100644 index 000000000..6beaa8108 --- /dev/null +++ b/mcandle/mcandle-core/src/cuda_backend/mod.rs @@ -0,0 +1,2467 @@ +use crate::backend::{BackendDevice, BackendStorage}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; +pub use cudarc; +use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; +use cudarc::driver::{ + CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, +}; +use float8::F8E4M3; +use half::{bf16, f16}; +pub use mcandle_kernels as kernels; + +#[cfg(feature = "cudnn")] +pub mod cudnn; +mod device; +mod error; +mod utils; +pub use device::{CudaDevice, DeviceId}; +pub use error::{CudaError, WrapErr}; +pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S}; + +pub enum SlicePtrOrNull { + Ptr(CudaSlice), + Null, +} + +unsafe impl DeviceRepr for &SlicePtrOrNull { + fn as_kernel_param(&self) -> *mut std::ffi::c_void { + match self { + SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(), + SlicePtrOrNull::Null => 0usize.as_kernel_param(), + } + } +} + +impl SlicePtrOrNull { + pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result { + let ds = if l.is_contiguous() { + SlicePtrOrNull::Null + } else { + SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?) + }; + Ok(ds) + } +} + +#[derive(Debug)] +pub enum CudaStorageSlice { + U8(CudaSlice), + U32(CudaSlice), + I16(CudaSlice), + I32(CudaSlice), + I64(CudaSlice), + BF16(CudaSlice), + F16(CudaSlice), + F32(CudaSlice), + F64(CudaSlice), + F8E4M3(CudaSlice), +} + +struct Clone; +impl Map1 for Clone { + fn f( + &self, + s: &CudaSlice, + _: &CudaDevice, + _: &Layout, + ) -> Result> { + s.try_clone().w() + } +} + +pub fn kernel_name(root: &str) -> String { + let dtype = T::DTYPE.as_str(); + format!("{root}_{dtype}") +} + +struct Affine(f64, f64); +impl Map1 for Affine { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("affine"), kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }.w()?; + let params = ( + el, + dims.len(), + &ds, + src, + &out, + T::from_f64(self.0), + T::from_f64(self.1), + ); + // SAFETY: ffi. 
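// A minimal standalone sketch of the `WrapErr` adapter pattern defined above, assuming
// a simple error type in place of the crate error (illustrative, std-only): every
// fallible cudarc call can then be wrapped with a single `.w()` instead of repeating
// the same `map_err` at each call site.
#[derive(Debug)]
struct BackendError(String);

trait Wrap<O> {
    fn w(self) -> Result<O, BackendError>;
}

impl<O, E: std::fmt::Debug> Wrap<O> for Result<O, E> {
    fn w(self) -> Result<O, BackendError> {
        self.map_err(|e| BackendError(format!("{e:?}")))
    }
}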
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Elu(f64); +impl Map1 for Elu { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("uelu"), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Im2Col1D { + l_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col1D { + fn l_out(&self, l: usize) -> usize { + (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 + } +} + +impl Map1 for Im2Col1D { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let l_out = self.l_out(dims[2]); + let dst_el = dims[0] * l_out * dims[1] * self.l_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("im2col1d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(dst_el) }.w()?; + let params = ( + dst_el, + l_out, + self.l_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + +#[allow(unused)] +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + #[allow(unused)] + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl Map1 for Im2Col { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let (h_out, w_out) = self.hw_out(dims[2], dims[3]); + let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("im2col"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(dst_el) }.w()?; + let params = ( + dst_el, + h_out, + w_out, + self.h_k, + self.w_k, + self.stride, + self.padding, + self.dilation, + &ds, + src, + &dst, + ); + // SAFETY: ffi. 
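// The same output-length formula as `Im2Col1D::l_out` above, as a standalone helper
// with a worked example (illustrative, std-only):
fn conv_l_out(l: usize, k: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (l + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}
// e.g. l = 32, k = 3, stride = 1, padding = 1, dilation = 1 gives conv_l_out(..) == 32,
// i.e. a "same"-style convolution preserves the input length.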
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } +} + +struct Powf(f64); +impl Map1 for Powf { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("upowf"), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct FastReduce<'a>(&'a [usize], ReduceOp); +impl<'a> Map1Any for FastReduce<'a> { + fn f) -> S>( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + wrap: W, + ) -> Result { + let src_stride = layout.stride(); + let src_dims = layout.shape().dims(); + let src_el: usize = src_dims.iter().product(); + // Source dims and strides with the sum dims at the end. + let mut dims = vec![]; + let mut stride = vec![]; + let mut dst_el: usize = 1; + for (dim_idx, &d) in src_dims.iter().enumerate() { + if !self.0.contains(&dim_idx) { + dst_el *= d; + dims.push(d); + stride.push(src_stride[dim_idx]); + } + } + for &dim_idx in self.0.iter() { + dims.push(src_dims[dim_idx]); + stride.push(src_stride[dim_idx]); + } + let el_to_sum_per_block = src_el / dst_el; + // The reduction loop requires the shared array to be properly initialized and for + // this we want the number of threads to be a power of two. + let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two(); + let cfg = LaunchConfig { + // TODO: Maybe use grid_y if the output is too large? + // TODO: Specialized implementation when reducing on no or all dimensions or when + // reducing only aggregate a small number of elements together. + grid_dim: (dst_el as u32, 1, 1), + block_dim: (block_dim as u32, 1, 1), + shared_mem_bytes: 0, + }; + let ds = dev + .htod_copy([dims.as_slice(), stride.as_slice()].concat()) + .w()?; + let src = &src.slice(layout.start_offset()..); + let (name, check_empty, return_index) = match self.1 { + ReduceOp::Sum => ("fast_sum", false, false), + ReduceOp::Min => ("fast_min", true, false), + ReduceOp::Max => ("fast_max", true, false), + ReduceOp::ArgMin => ("fast_argmin", true, true), + ReduceOp::ArgMax => ("fast_argmax", true, true), + }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? + } + let func = dev.get_or_load_func(&kernel_name::(name), kernels::REDUCE)?; + if return_index { + // SAFETY: filled in by the follow up kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(S::U32(out)) + } else { + // SAFETY: filled in by the follow up kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); + // SAFETY: ffi. 
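// A minimal standalone sketch of the shape bookkeeping in `FastReduce` above: kept
// dimensions go first, reduced dimensions last, and the block size is the smallest
// power of two covering the per-output reduction, capped at 1024 (illustrative,
// std-only):
fn reduce_plan(src_dims: &[usize], sum_dims: &[usize]) -> (Vec<usize>, usize, usize, usize) {
    let src_el: usize = src_dims.iter().product();
    let mut dims = vec![];
    let mut dst_el = 1usize;
    for (dim_idx, &d) in src_dims.iter().enumerate() {
        if !sum_dims.contains(&dim_idx) {
            dst_el *= d;
            dims.push(d);
        }
    }
    for &dim_idx in sum_dims {
        dims.push(src_dims[dim_idx]);
    }
    let el_to_sum_per_block = src_el / dst_el;
    let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
    (dims, dst_el, el_to_sum_per_block, block_dim)
}
// e.g. src_dims = [2, 3, 4] with sum_dims = [2] gives dst_el = 6, four elements summed
// per output and a block_dim of 4.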
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(wrap(out)) + } + } +} + +impl Map1 for U { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }.w()?; + let params = (el_count, dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize); +impl<'a> Map1 for IndexSelect<'a> { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + src_l: &Layout, + ) -> Result> { + let ids_l = &self.1; + let (name, ids) = match &self.0.slice { + CudaStorageSlice::U32(slice) => { + ("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr()) + } + CudaStorageSlice::U8(slice) => { + ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr()) + } + CudaStorageSlice::I16(slice) => { + ("is_i16", *slice.slice(ids_l.start_offset()..).device_ptr()) + } + CudaStorageSlice::I32(slice) => { + ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr()) + } + CudaStorageSlice::I64(slice) => { + ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr()) + } + _ => Err(CudaError::UnexpectedDType { + msg: "index_select ids should be u8/u32/i16/i32/i64", + expected: DType::U32, + got: self.0.dtype(), + }) + .w()?, + }; + let ids_shape = ids_l.shape(); + let ids_dims = ids_shape.dims(); + let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; + let src = match src_l.contiguous_offsets() { + Some((o1, o2)) => src.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?, + }; + let left_size: usize = src_l.dims()[..self.2].iter().product(); + let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); + let src_dim_size = src_l.dims()[self.2]; + let ids_dim_size = ids_shape.elem_count(); + let dst_el = ids_shape.elem_count() * left_size * right_size; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let params = ( + dst_el, + ids_dims.len(), + &ds, + ids, + &src, + &out, + left_size, + src_dim_size, + ids_dim_size, + right_size, + ); + // SAFETY: ffi. 
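// A CPU reference for the index-select decomposition above, assuming a contiguous
// source viewed as (left_size, src_dim_size, right_size) and u32 indices (names are
// illustrative, std-only): output position (l, i, r) reads src at (l, ids[i], r).
fn index_select_ref<T: Copy>(
    src: &[T],
    ids: &[u32],
    left_size: usize,
    src_dim_size: usize,
    right_size: usize,
) -> Vec<T> {
    let mut out = Vec::with_capacity(left_size * ids.len() * right_size);
    for l in 0..left_size {
        for &id in ids {
            let base = (l * src_dim_size + id as usize) * right_size;
            out.extend_from_slice(&src[base..base + right_size]);
        }
    }
    out
}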
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Gather<'a>(&'a CudaStorage, &'a Layout, usize); +impl<'a> Map1 for Gather<'a> { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + src_l: &Layout, + ) -> Result> { + let ids = &self.0; + let ids_l = &self.1; + let dim = self.2; + let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let (name, ids) = match &ids.slice { + CudaStorageSlice::U32(slice) => { + ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr()) + } + CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => { + ("gather_i16", *slice.slice(ids_o1..ids_o2).device_ptr()) + } + CudaStorageSlice::I32(slice) => { + ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr()) + } + CudaStorageSlice::I64(slice) => { + ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr()) + } + _ => Err(CudaError::UnexpectedDType { + msg: "gather ids should be u8/u32/i16/i32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let el = ids_l.shape().elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let src = match src_l.contiguous_offsets() { + Some((o1, o2)) => src.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + }; + let left_sz: usize = src_l.dims()[..dim].iter().product(); + let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); + let src_dim_sz = src_l.dims()[dim]; + let ids_dim_sz = ids_l.dims()[dim]; + let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }.w()?; + let params = ( + el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz, + ); + // SAFETY: ffi. 
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize); +impl<'a> Map2InPlace for IndexAdd<'a> { + fn f( + &self, + dst: &mut CudaSlice, + dst_shape: &Shape, + src: &CudaSlice, + src_l: &Layout, + dev: &CudaDevice, + ) -> Result<()> { + let ids = &self.0; + let ids_l = &self.1; + let dim = self.2; + let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let (name, ids) = match &ids.slice { + CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => ("ia_i16", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + _ => Err(CudaError::UnexpectedDType { + msg: "index-add ids should be u8/u32/i16/i32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let src = match src_l.contiguous_offsets() { + Some((o1, o2)) => src.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + let left_sz: usize = src_l.dims()[..dim].iter().product(); + let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); + let src_dim_sz = src_l.dims()[dim]; + let dst_dim_sz = dst_shape.dims()[dim]; + let ids_dim_sz = ids_l.dims()[0]; + let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); + let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; + // SAFETY: Set later by running the kernel. + let params = ( + ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz, + ); + // SAFETY: ffi. 
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(()) + } +} + +struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize); +impl<'a> Map2InPlace for ScatterAdd<'a> { + fn f( + &self, + dst: &mut CudaSlice, + dst_shape: &Shape, + src: &CudaSlice, + src_l: &Layout, + dev: &CudaDevice, + ) -> Result<()> { + let ids = &self.0; + let ids_l = &self.1; + let dim = self.2; + let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { + Some(o12) => o12, + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let (name, ids) = match &ids.slice { + CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I16(slice) => ("sa_i16", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + _ => Err(CudaError::UnexpectedDType { + msg: "scatter-add ids should be u8/u32/i16/i32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let src = match src_l.contiguous_offsets() { + Some((o1, o2)) => src.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; + let left_sz: usize = src_l.dims()[..dim].iter().product(); + let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); + let src_dim_sz = src_l.dims()[dim]; + let dst_dim_sz = dst_shape.dims()[dim]; + let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); + let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; + // SAFETY: Set later by running the kernel. + let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(()) + } +} + +struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); +impl<'a> Map2 for Conv1D<'a> { + fn f( + &self, + inp: &CudaSlice, + inp_l: &Layout, + k: &CudaSlice, + k_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + // Kernel shape: (c_out, c_in_k, k_size) + // Input shape: (b_size, c_in, l_in) or (c_in, l_in) + let p = &self.0; + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(k_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let l_out = p.l_out(); + let dst_el = p.c_out * l_out * p.b_size; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::("conv1d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let ds = if dims.len() == 3 { + [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() + } else if dims.len() == 2 { + [&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat() + } else { + crate::bail!("unexpected input shape for conv1d {dims:?}") + }; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out, + ); + // SAFETY: ffi. 
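// A CPU reference for the scatter-add decomposition above, assuming contiguous buffers
// viewed as (left, src_dim, right) for src/ids and (left, dst_dim, right) for dst
// (names are illustrative, std-only): each source element is accumulated into the
// destination row selected by the matching index.
fn scatter_add_ref(
    dst: &mut [f32],
    src: &[f32],
    ids: &[u32],
    left: usize,
    src_dim: usize,
    dst_dim: usize,
    right: usize,
) {
    for l in 0..left {
        for i in 0..src_dim {
            for r in 0..right {
                let s = (l * src_dim + i) * right + r;
                let d = (l * dst_dim + ids[s] as usize) * right + r;
                dst[d] += src[s];
            }
        }
    }
}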
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); +impl<'a> Map2 for Conv2D<'a> { + fn f( + &self, + inp: &CudaSlice, + inp_l: &Layout, + k: &CudaSlice, + k_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + // Kernel shape: (c_out, c_in_k, h_k, w_k) + // Input shape: (b_size, c_in, h_in, w_in) + let p = &self.0; + let (out_w, out_h) = (p.out_w(), p.out_h()); + let dst_el = p.c_out * out_w * out_h * p.b_size; + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(k_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::("conv2d"), kernels::CONV)?; + let ds = if dims.len() == 4 { + [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() + } else { + crate::bail!("unexpected input shape for conv2d {dims:?}") + }; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Col2Im1D { + stride: usize, +} + +impl Map1 for Col2Im1D { + fn f( + &self, + col: &CudaSlice, + dev: &CudaDevice, + l: &Layout, + ) -> Result> { + let (b_size, l_in, c_out, k_size) = l.shape().dims4()?; + let stride = self.stride; + let l_out = (l_in - 1) * stride + k_size; + let dst_el = b_size * c_out * l_out; + let mut im = unsafe { dev.alloc::(dst_el) }.w()?; + + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im); + let func = dev.get_or_load_func(&kernel_name::("col2im1d"), kernels::CONV)?; + unsafe { func.launch(cfg, params) }.w()?; + Ok(im) + } +} + +struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); +impl<'a> Map2 for ConvTranspose1D<'a> { + fn f( + &self, + inp: &CudaSlice, + inp_l: &Layout, + k: &CudaSlice, + k_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + // Kernel shape: (c_in_k, c_out, l_k) + // Input shape: (b_size, c_in, l_in) + let p = &self.0; + let l_out = p.l_out(); + let dst_el = p.c_out * l_out * p.b_size; + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(k_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::("conv_transpose1d"), kernels::CONV)?; + let ds = if dims.len() == 3 { + [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() + } else { + crate::bail!("unexpected input shape for conv_transpose1d {dims:?}") + }; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, + l_out, + p.stride, + p.padding, + p.output_padding, + p.dilation, + &ds, + inp, + k, + &out, + ); + // SAFETY: ffi. 
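// The same output-length relation as `Col2Im1D` above, as a standalone helper with a
// worked example (illustrative, std-only): with no padding or dilation, folding the
// columns back maps l_in positions to (l_in - 1) * stride + k_size samples.
fn col2im1d_l_out(l_in: usize, stride: usize, k_size: usize) -> usize {
    (l_in - 1) * stride + k_size
}
// e.g. l_in = 16, stride = 2, k_size = 4 gives 34 output samples.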
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); +impl<'a> Map2 for ConvTranspose2D<'a> { + fn f( + &self, + inp: &CudaSlice, + inp_l: &Layout, + k: &CudaSlice, + k_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + // Kernel shape: (c_in_k, c_out, h_k, w_k) + // Input shape: (b_size, c_in, h_in, w_in) + let p = &self.0; + let (out_w, out_h) = (p.out_w(), p.out_h()); + let dst_el = p.c_out * out_w * out_h * p.b_size; + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(k_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::("conv_transpose2d"), kernels::CONV)?; + let ds = if dims.len() == 4 { + [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() + } else { + crate::bail!("unexpected input shape for conv_transpose2d {dims:?}") + }; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, + out_w, + out_h, + p.stride, + p.padding, + p.output_padding, + p.dilation, + &ds, + inp, + k, + &out, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +enum PoolOp { + Max, + Avg, +} + +struct Pool2D { + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + op: PoolOp, +} + +impl Map1 for Pool2D { + fn f( + &self, + inp: &CudaSlice, + dev: &CudaDevice, + inp_l: &Layout, + ) -> Result> { + // Input shape: (b_size, c, h, w) + let inp = &inp.slice(inp_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let ds = if dims.len() == 4 { + [dims, inp_l.stride()].concat() + } else { + crate::bail!("unexpected input shape for pool {dims:?}") + }; + let el = shape.elem_count(); + let out_w = (dims[2] - self.w_k) / self.w_stride + 1; + let out_h = (dims[3] - self.h_k) / self.h_stride + 1; + let dst_el = out_w * out_h * dims[0] * dims[1]; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let kname = match self.op { + PoolOp::Max => "max_pool2d", + PoolOp::Avg => "avg_pool2d", + }; + let func = dev.get_or_load_func(&kernel_name::(kname), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(dst_el) }.w()?; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, + self.w_k, + self.h_k, + self.w_stride, + self.h_stride, + &ds, + inp, + &out, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct UpsampleNearest2D(usize, usize); +impl Map1 for UpsampleNearest2D { + fn f( + &self, + inp: &CudaSlice, + dev: &CudaDevice, + inp_l: &Layout, + ) -> Result> { + // Input shape: (b_size, c, h, w) + let inp = &inp.slice(inp_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let ds = if dims.len() == 4 { + [dims, inp_l.stride()].concat() + } else { + crate::bail!("unexpected input shape for upsample {dims:?}") + }; + let (out_w, out_h) = (self.0, self.1); + let dst_el = out_w * out_h * dims[0] * dims[1]; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let func = dev.get_or_load_func(&kernel_name::("upsample_nearest2d"), kernels::CONV)?; + // SAFETY: Set later by running the kernel. 
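// The same output-size rule as `Pool2D` above, as a standalone helper with a worked
// example (illustrative, std-only): pooling here has no padding, so every output index
// consumes one full window.
fn pool_out_dim(in_dim: usize, k: usize, stride: usize) -> usize {
    (in_dim - k) / stride + 1
}
// e.g. in_dim = 32, k = 2, stride = 2 gives 16 pooled positions.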
+ let out = unsafe { dev.alloc::(dst_el) }.w()?; + let ds = dev.htod_copy(ds).w()?; + let scale_w = dims[2] as f64 / out_w as f64; + let scale_h = dims[3] as f64 / out_h as f64; + let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct WhereCond<'a>(&'a CudaStorage, &'a Layout); +impl<'a> Map2 for WhereCond<'a> { + fn f( + &self, + t: &CudaSlice, + layout_t: &Layout, + f: &CudaSlice, + layout_f: &Layout, + dev: &CudaDevice, + ) -> Result> { + let ids_l = &self.1; + let (ids, name) = match &self.0.slice { + CudaStorageSlice::U8(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_u8") + } + CudaStorageSlice::U32(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_u32") + } + CudaStorageSlice::I16(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i16") + } + CudaStorageSlice::I32(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i32") + } + CudaStorageSlice::I64(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i64") + } + _ => Err(CudaError::UnexpectedDType { + msg: "where conditions should be u8/u32/i16/i32/i64", + expected: DType::U32, + got: self.0.dtype(), + }) + .w()?, + }; + let shape = ids_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev + .htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat()) + .w()?; + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(name), kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +impl Map2 for U { + fn f( + &self, + lhs: &CudaSlice, + lhs_l: &Layout, + rhs: &CudaSlice, + rhs_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + let shape = lhs_l.shape(); + let dims = shape.dims(); + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + SlicePtrOrNull::Null + } else { + SlicePtrOrNull::Ptr( + dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) + .w()?, + ) + }; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::BINARY)?; + // SAFETY: Set later by running the kernel. 
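// A CPU reference for the ternary select dispatched above, assuming the condition has
// already been materialized as a u8 mask with the same element count as `t` and `f`
// (names are illustrative, std-only): a non-zero mask entry picks the `t` value, zero
// picks the `f` value.
fn where_cond_ref<T: Copy>(cond: &[u8], t: &[T], f: &[T]) -> Vec<T> {
    cond.iter()
        .zip(t.iter().zip(f.iter()))
        .map(|(&c, (&tv, &fv))| if c != 0 { tv } else { fv })
        .collect()
}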
+ let out = unsafe { dev.alloc::(elem_count) }.w()?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + +struct Cmp(CmpOp); +impl Map2Any for Cmp { + fn f( + &self, + lhs: &CudaSlice, + lhs_l: &Layout, + rhs: &CudaSlice, + rhs_l: &Layout, + dev: &CudaDevice, + ) -> Result { + let shape = lhs_l.shape(); + let dims = shape.dims(); + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + SlicePtrOrNull::Null + } else { + SlicePtrOrNull::Ptr( + dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) + .w()?, + ) + }; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let name = match self.0 { + CmpOp::Eq => "eq", + CmpOp::Ne => "ne", + CmpOp::Lt => "lt", + CmpOp::Le => "le", + CmpOp::Gt => "gt", + CmpOp::Ge => "ge", + }; + let func = dev.get_or_load_func(&kernel_name::(name), kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }.w()?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }.w()?; + Ok(S::U8(out)) + } +} + +fn slice_src_and_dst<'a, T>( + src: &'a CudaSlice, + src_l: &Layout, + dst: &'a mut CudaSlice, + dst_offset: usize, +) -> ( + cudarc::driver::CudaView<'a, T>, + cudarc::driver::CudaViewMut<'a, T>, +) { + let src_offset = src_l.start_offset(); + let to_copy = dst + .len() + .saturating_sub(dst_offset) + .min(src.len().saturating_sub(src_offset)); + let src = src.slice(src_offset..src_offset + to_copy); + let dst = dst.slice_mut(dst_offset..dst_offset + to_copy); + (src, dst) +} + +#[derive(Debug)] +pub struct CudaStorage { + pub slice: CudaStorageSlice, + pub device: CudaDevice, +} + +pub trait CudaDType: Sized { + fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice>; + fn wrap_cuda_slice(s: CudaSlice, dev: CudaDevice) -> CudaStorage; +} + +macro_rules! 
cuda_dtype { + ($ty:ty, $dtype:ident) => { + impl CudaDType for $ty { + fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice> { + match &s.slice { + CudaStorageSlice::$dtype(data) => Ok(&data), + _ => Err(crate::Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + msg: "unexpected dtype", + } + .bt()), + } + } + + fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { + let slice = CudaStorageSlice::$dtype(slice); + CudaStorage { slice, device } + } + } + }; +} +cuda_dtype!(u8, U8); +cuda_dtype!(u32, U32); +cuda_dtype!(i16, I16); +cuda_dtype!(i32, I32); +cuda_dtype!(i64, I64); +cuda_dtype!(f16, F16); +cuda_dtype!(bf16, BF16); +cuda_dtype!(f32, F32); +cuda_dtype!(f64, F64); +cuda_dtype!(F8E4M3, F8E4M3); + +impl CudaStorage { + pub fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { + T::wrap_cuda_slice(slice, device) + } + + pub fn as_cuda_slice(&self) -> Result<&CudaSlice> { + T::as_cuda_slice(self) + } +} + +fn gemm_config( + alpha: T, + beta: T, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, +) -> Result> { + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm + use cudarc::cublas::sys::cublasOperation_t; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // The a tensor has dims batching, k, n (rhs) + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, cublasOperation_t::CUBLAS_OP_N) + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { + (k as i32, cublasOperation_t::CUBLAS_OP_T) + } else { + Err(CudaError::MatMulNonContiguous { + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), + mnk: (m, n, k), + })? + }; + // The b tensor has dims batching, m, k (lhs) + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, cublasOperation_t::CUBLAS_OP_N) + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { + (m as i32, cublasOperation_t::CUBLAS_OP_T) + } else { + Err(CudaError::MatMulNonContiguous { + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), + mnk: (m, n, k), + })? 
+ }; + // The setup below was copied from: + // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531 + let gemm = GemmConfig { + alpha, + beta, + m: n as i32, + n: m as i32, + k: k as i32, + lda, + ldb, + ldc: n as i32, + transa, + transb, + }; + + let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { + [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, + [_, stride] if lhs_l.dims()[0] == 1 => stride, + [stride, _] if lhs_l.dims()[1] == 1 => stride, + [stride] => stride, + [] => m * k, + _ => Err(CudaError::MatMulNonContiguous { + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), + mnk: (m, n, k), + })?, + }; + let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { + [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, + [_, stride] if rhs_l.dims()[0] == 1 => stride, + [stride, _] if rhs_l.dims()[1] == 1 => stride, + [stride] => stride, + [] => n * k, + _ => Err(CudaError::MatMulNonContiguous { + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), + mnk: (m, n, k), + })?, + }; + + Ok(StridedBatchedConfig { + batch_size: b as i32, + gemm, + stride_a: stride_a as i64, + stride_b: stride_b as i64, + stride_c: (m * n) as i64, + }) +} + +impl BackendStorage for CudaStorage { + type Device = CudaDevice; + + fn try_clone(&self, layout: &Layout) -> Result { + let slice = Clone.map(&self.slice, self.device(), layout)?; + let device = self.device.clone(); + Ok(Self { slice, device }) + } + + fn dtype(&self) -> DType { + match self.slice { + CudaStorageSlice::U8(_) => DType::U8, + CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I16(_) => DType::I16, + CudaStorageSlice::I32(_) => DType::I32, + CudaStorageSlice::I64(_) => DType::I64, + CudaStorageSlice::BF16(_) => DType::BF16, + CudaStorageSlice::F16(_) => DType::F16, + CudaStorageSlice::F32(_) => DType::F32, + CudaStorageSlice::F64(_) => DType::F64, + CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, + } + } + + fn device(&self) -> &CudaDevice { + &self.device + } + + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let dev = self.device(); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let start_o = layout.start_offset(); + // This returns an i64 rather than a &i64, this is useful to get around some temporary + // lifetime issue and is safe as long as self.slice does not go out of scope before inp + // is used. 
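A minimal standalone sketch of the pattern the comment above describes (illustration only, not part of the diff; the helper name raw_device_ptr is hypothetical): dereferencing `device_ptr()` copies the raw device address out of a temporary view, so the short-lived borrow of `self.slice` ends immediately while the address stays valid for as long as the backing slice lives.

    fn raw_device_ptr<T: cudarc::driver::DeviceRepr>(
        slice: &cudarc::driver::CudaSlice<T>,
        offset: usize,
    ) -> cudarc::driver::sys::CUdeviceptr {
        use cudarc::driver::DevicePtr;
        // `slice(..)` creates a short-lived CudaView; `*view.device_ptr()` copies the
        // CUdeviceptr (a plain integer) out before the view is dropped. The pointer is
        // only usable while the original CudaSlice is alive, as noted above.
        *slice.slice(offset..).device_ptr()
    }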
+ let inp = match &self.slice { + CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F8E4M3(inp) => *inp.slice(start_o..).device_ptr(), + }; + let inp = &inp; + + let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); + let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; + let slice = match dtype { + DType::U8 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(out) + } + DType::U32 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U32(out) + } + DType::I16 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I16(out) + } + DType::I32 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(out) + } + DType::I64 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I64(out) + } + DType::BF16 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(out) + } + DType::F16 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(out) + } + DType::F32 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(out) + } + DType::F64 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(out) + } + DType::F8E4M3 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(out) + } + }; + Ok(Self { + slice, + device: dev.clone(), + }) + } + + fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let device = self.device().clone(); + let slice = Affine(mul, add).map(&self.slice, &device, layout)?; + Ok(Self { slice, device }) + } + + fn powf(&self, layout: &Layout, e: f64) -> Result { + let device = self.device().clone(); + let slice = Powf(e).map(&self.slice, &device, layout)?; + Ok(Self { slice, device }) + } + + fn elu(&self, layout: &Layout, alpha: f64) -> Result { + let device = self.device().clone(); + let slice = Elu(alpha).map(&self.slice, &device, layout)?; + Ok(Self { slice, device }) + } + + fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { + 
let device = self.device().clone(); + let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?; + Ok(Self { slice, device }) + } + + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { + let device = self.device().clone(); + let slice = Cmp(op).map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; + Ok(Self { slice, device }) + } + + fn unary_impl(&self, layout: &Layout) -> Result { + let device = self.device().clone(); + let slice = U::V.map(&self.slice, &device, layout)?; + Ok(Self { slice, device }) + } + + fn binary_impl( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device().clone(); + let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; + Ok(Self { slice, device }) + } + + fn to_cpu_storage(&self) -> Result { + match &self.slice { + CudaStorageSlice::U8(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::U8(cpu_storage)) + } + CudaStorageSlice::U32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::U32(cpu_storage)) + } + CudaStorageSlice::I16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I16(cpu_storage)) + } + CudaStorageSlice::I32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I32(cpu_storage)) + } + CudaStorageSlice::I64(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I64(cpu_storage)) + } + CudaStorageSlice::BF16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::BF16(cpu_storage)) + } + CudaStorageSlice::F16(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F16(cpu_storage)) + } + CudaStorageSlice::F32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F32(cpu_storage)) + } + CudaStorageSlice::F64(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F64(cpu_storage)) + } + CudaStorageSlice::F8E4M3(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F8E4M3(cpu_storage)) + } + } + } + + fn where_cond( + &self, + layout: &Layout, + t: &Self, + t_l: &Layout, + f: &Self, + f_l: &Layout, + ) -> Result { + let device = self.device().clone(); + let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?; + Ok(Self { slice, device }) + } + + fn conv1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv1D, + ) -> Result { + const USE_IM2COL_CONV1D: bool = true; + + let device = self.device().clone(); + if !USE_IM2COL_CONV1D { + let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col1D { + l_k: params.k_size, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let l_out = params.l_out(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_size * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), 
kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) + } + + fn conv_transpose1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConvTranspose1D, + ) -> Result { + const USE_COL2IM_CONV1D_TR: bool = true; + + let device = self.device().clone(); + let can_use_col2im = kernel_l.is_contiguous() + && params.dilation == 1 + && params.padding == 0 + && params.output_padding == 0; + let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im { + let (b_size, c_in, l_in) = l.shape().dims3()?; + let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?; + if !kernel_l.is_contiguous() { + crate::bail!( + "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}" + ) + } + if c_in != c_in2 { + crate::bail!( + "convtr1d: shape mismatch on c_in {:?} {:?}", + l.shape(), + kernel_l.shape() + ) + } + let col = { + // This merges the last two dimensions of the kernel together. + let kernel_l_mm = Layout::new( + (b_size, c_in, k_size * c_out).into(), + vec![0, k_size * c_out, 1], + kernel_l.start_offset(), + ); + self.matmul_with_alpha( + kernel, + None, + ( + b_size, + /* m */ l_in, + /* n */ c_out * k_size, + /* k */ c_in, + ), + &l.transpose(1, 2)?, + &kernel_l_mm, + )? + }; + let col_l = Layout::contiguous((b_size, l_in, c_out, k_size)); + Col2Im1D { + stride: params.stride, + } + .map(&col.slice, &device, &col_l)? + } else { + ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)? + }; + Ok(Self { slice, device }) + } + + #[cfg(not(feature = "cudnn"))] + fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result { + const USE_IM2COL_CONV2D: bool = true; + + let device = self.device().clone(); + if !USE_IM2COL_CONV2D { + let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + + let col = Im2Col { + h_k: params.k_h, + w_k: params.k_w, + stride: params.stride, + dilation: params.dilation, + padding: params.padding, + } + .map(&self.slice, &device, l)?; + let col = Self { slice: col, device }; + let h_out = params.out_h(); + let w_out = params.out_w(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_h * params.k_w * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? 
+ }; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, n)) + .transpose(1, 2)? + .transpose(1, 3)?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) + } + + #[cfg(feature = "cudnn")] + fn conv2d( + &self, + inp_l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result { + let device = self.device().clone(); + if !kernel_l.is_contiguous() { + let slice = Conv2D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?; + return Ok(Self { slice, device }); + } + let (out_w, out_h) = (params.out_w(), params.out_h()); + let dst_el = params.c_out * out_w * out_h * params.b_size; + let slice = match (&self.slice, &kernel.slice) { + (S::U8(inp), S::U8(k)) => { + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(kernel_l.start_offset()..); + let mut out = unsafe { device.alloc::(dst_el) }.w()?; + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + .map_err(crate::Error::wrap)?; + S::U8(out) + } + (S::BF16(inp), S::BF16(k)) => { + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(kernel_l.start_offset()..); + let mut out = unsafe { device.alloc::(dst_el) }.w()?; + // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16" + // version. + // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88 + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + .map_err(crate::Error::wrap)?; + S::BF16(out) + } + (S::F16(inp), S::F16(k)) => { + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(kernel_l.start_offset()..); + let mut out = unsafe { device.alloc::(dst_el) }.w()?; + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + .map_err(crate::Error::wrap)?; + S::F16(out) + } + (S::F32(inp), S::F32(k)) => { + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(kernel_l.start_offset()..); + let mut out = unsafe { device.alloc::(dst_el) }.w()?; + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + .map_err(crate::Error::wrap)?; + S::F32(out) + } + (S::F64(inp), S::F64(k)) => { + let inp = &inp.slice(inp_l.start_offset()..); + let k = &k.slice(kernel_l.start_offset()..); + let mut out = unsafe { device.alloc::(dst_el) }.w()?; + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + .map_err(crate::Error::wrap)?; + S::F64(out) + } + (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I16(_), S::I16(_)) => Err(CudaError::InternalError("conv2d does not support i16"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, + (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, + _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, + }; + Ok(Self { slice, device }) + } + + fn conv_transpose2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + let device = self.device().clone(); + let slice = + ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; + Ok(Self { slice, device }) 
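To make the im2col bookkeeping in the conv1d/conv2d paths above easier to follow, here is a small self-contained sketch (illustration only, assuming the usual convolution output-size formula) of how the convolution is reshaped into the (b, m, k) x (b, k, n) batched matmul that matmul_with_alpha receives:

    fn im2col_matmul_dims(
        b: usize,
        (c_in, c_out): (usize, usize),
        (h, w): (usize, usize),
        (k_h, k_w): (usize, usize),
        (stride, dilation, padding): (usize, usize, usize),
    ) -> ((usize, usize, usize), (usize, usize, usize)) {
        let h_out = (h + 2 * padding - dilation * (k_h - 1) - 1) / stride + 1;
        let w_out = (w + 2 * padding - dilation * (k_w - 1) - 1) / stride + 1;
        let m = h_out * w_out; // one matmul row per output position
        let k = k_h * k_w * c_in; // one column per kernel tap and input channel
        let n = c_out; // one output column per output channel
        ((b, m, k), (b, k, n))
    }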
+ } + + fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { + let device = self.device().clone(); + let slice = Pool2D { + w_k: k.0, + h_k: k.1, + w_stride: stride.0, + h_stride: stride.1, + op: PoolOp::Avg, + } + .map(&self.slice, &device, l)?; + Ok(Self { slice, device }) + } + + fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { + let device = self.device().clone(); + let slice = Pool2D { + w_k: k.0, + h_k: k.1, + w_stride: stride.0, + h_stride: stride.1, + op: PoolOp::Max, + } + .map(&self.slice, &device, l)?; + Ok(Self { slice, device }) + } + + fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result { + crate::bail!("upsample-nearest1d is not supported on cuda") + } + + fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result { + let device = self.device().clone(); + let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; + Ok(Self { slice, device }) + } + + fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result { + let device = self.device().clone(); + let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?; + Ok(Self { slice, device }) + } + fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { + let device = self.device().clone(); + let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?; + Ok(Self { slice, device }) + } + fn scatter_add( + &self, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, + ) -> Result { + let device = self.device().clone(); + let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; + self.copy_strided_src(&mut acc, 0, l)?; + ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; + Ok(acc) + } + fn index_add( + &self, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, + ) -> Result { + let device = self.device().clone(); + let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? 
}; + self.copy_strided_src(&mut acc, 0, l)?; + IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; + Ok(acc) + } + + fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + let elem_count = b * m * n; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != elem_count { + crate::bail!("`c` end offset must be {}", elem_count) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + match (&self.slice, &rhs.slice, &mut c.slice) { + ( + CudaStorageSlice::BF16(lhs), + CudaStorageSlice::BF16(rhs), + CudaStorageSlice::BF16(c), + ) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + bf16::from_f64(s.unwrap_or(1.0)), + bf16::ONE, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs), CudaStorageSlice::F16(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + f16::from_f64(s.unwrap_or(1.0)), + f16::ONE, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs), CudaStorageSlice::F32(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(s.unwrap_or(1.0) as f32, 1., (b, m, n, k), lhs_l, rhs_l)?; + unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, c) }.w()?; + } + (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs), CudaStorageSlice::F64(c)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(s.unwrap_or(1.0), 1., (b, m, n, k), lhs_l, rhs_l)?; + unsafe { self.device.blas.gemm_strided_batched(cfg, rhs, lhs, c) }.w()?; + } + _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, + }; + Ok(()) + } + + fn matmul_with_alpha( + &self, + rhs: &Self, + scale: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let elem_count = b * m * n; + let dev = &self.device; + let slice = match (&self.slice, &rhs.slice) { + (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + bf16::from_f64(scale.unwrap_or(1.)), + bf16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) } + .w()?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config( + f16::from_f64(scale.unwrap_or(1.)), + f16::ZERO, + (b, m, n, k), + lhs_l, + rhs_l, + )?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) } + .w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { + let lhs = 
&lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(scale.unwrap_or(1.) as f32, 0., (b, m, n, k), lhs_l, rhs_l)?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) } + .w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(scale.unwrap_or(1.), 0., (b, m, n, k), lhs_l, rhs_l)?; + let mut out = unsafe { dev.alloc::(elem_count) }.w()?; + unsafe { + self.device + .blas + .gemm_strided_batched(cfg, rhs, lhs, &mut out) + } + .w()?; + CudaStorageSlice::F64(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, + }; + let device = dev.clone(); + Ok(Self { slice, device }) + } + + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + let dev = &self.device; + let d1 = d1 as u32; + let d2 = d2 as u32; + // Nothing to copy so we exit early to avoid launching a kernel and some potential invalid + // argument with a null pointer. + if d1 == 0 || d2 == 0 { + return Ok(()); + } + let dst_s = dst_s as u32; + let src_s = src_s as u32; + let (src, dst, kname) = match (&self.slice, &mut dst.slice) { + (S::U8(s), S::U8(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_u8", + ), + (S::U32(s), S::U32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_u32", + ), + (S::I16(s), S::I16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i16", + ), + (S::I32(s), S::I32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i32", + ), + (S::I64(s), S::I64(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i64", + ), + (S::BF16(s), S::BF16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_bf16", + ), + (S::F16(s), S::F16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f16", + ), + (S::F32(s), S::F32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f32", + ), + (S::F64(s), S::F64(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f64", + ), + (S::F8E4M3(s), S::F8E4M3(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f8_e4m3", + ), + _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, + }; + let func = dev.get_or_load_func(kname, kernels::FILL)?; + let cfg = LaunchConfig::for_num_elems(d1 * d2); + let params = (src, dst, d1, d2, src_s, dst_s); + // SAFETY: ffi. 
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(()) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let src_shape = src_l.shape(); + let dims = src_shape.dims(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let dev = &self.device; + let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?; + match (&self.slice, &mut dst.slice) { + (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::F8E4M3(src), CudaStorageSlice::F8E4M3(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_f8_e4m3", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::I16(src), CudaStorageSlice::I16(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i16", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. 
+ let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } + (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + } + } + _ => Err(CudaError::InternalError( + "dtype mismatch in copy_strided op", + ))?, + } + Ok(()) + } +} + +// Default for the reduced precision setting is false, similar to pytorch. +// https://github.com/pytorch/pytorch/issues/123157 +static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool = + std::sync::atomic::AtomicBool::new(false); +static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool = + std::sync::atomic::AtomicBool::new(false); +static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool = + std::sync::atomic::AtomicBool::new(false); + +/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are +/// allowed with f32 GEMMs. +pub fn gemm_reduced_precision_f32() -> bool { + MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed) +} + +/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are +/// allowed with f32 GEMMs. +pub fn set_gemm_reduced_precision_f32(b: bool) { + MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed) +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with f16 GEMMs. +pub fn gemm_reduced_precision_f16() -> bool { + MM_F16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed) +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with f16 GEMMs. +pub fn set_gemm_reduced_precision_f16(b: bool) { + MM_F16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed) +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with bf16 GEMMs. +pub fn gemm_reduced_precision_bf16() -> bool { + MM_BF16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed) +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with bf16 GEMMs. 
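These reduced-precision flags are plain process-wide atomics, so a caller opts into faster, lower-precision accumulation once before launching matmuls. A minimal sketch, written as if it lived in this module (callers elsewhere would go through whatever re-export the crate root provides; the function name is hypothetical):

    #[allow(dead_code)]
    fn enable_fast_gemm_example() {
        // Allow TF32 accumulation for f32 GEMMs and fp16 accumulation for f16 GEMMs.
        set_gemm_reduced_precision_f32(true);
        set_gemm_reduced_precision_f16(true);
        // bf16 GEMMs keep full f32 accumulation unless explicitly opted in.
        assert!(!gemm_reduced_precision_bf16());
    }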
+pub fn set_gemm_reduced_precision_bf16(b: bool) { + MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed) +} + +unsafe fn gemm_strided_batched_f32( + cublas: &cudarc::cublas::CudaBlas, + cfg: StridedBatchedConfig, + a: &cudarc::driver::CudaView, + b: &cudarc::driver::CudaView, + c: &mut CudaSlice, +) -> std::result::Result<(), cudarc::cublas::result::CublasError> { + use cudarc::cublas::sys; + use cudarc::driver::DevicePtrMut; + + let compute_type = if gemm_reduced_precision_f32() { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 + } else { + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + }; + let alpha = &cfg.gemm.alpha as *const f32 as *const _; + let beta = &cfg.gemm.beta as *const f32 as *const _; + + cudarc::cublas::result::gemm_strided_batched_ex( + *cublas.handle(), + cfg.gemm.transa, + cfg.gemm.transb, + cfg.gemm.m, + cfg.gemm.n, + cfg.gemm.k, + alpha, + *a.device_ptr() as *const _, + sys::cudaDataType_t::CUDA_R_32F, + cfg.gemm.lda, + cfg.stride_a, + *b.device_ptr() as *const _, + sys::cudaDataType_t::CUDA_R_32F, + cfg.gemm.ldb, + cfg.stride_b, + beta, + *c.device_ptr_mut() as *mut _, + sys::cudaDataType_t::CUDA_R_32F, + cfg.gemm.ldc, + cfg.stride_c, + cfg.batch_size, + compute_type, + sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, + ) +} + +unsafe fn gemm_strided_batched_f16( + cublas: &cudarc::cublas::CudaBlas, + cfg: StridedBatchedConfig, + a: &cudarc::driver::CudaView, + b: &cudarc::driver::CudaView, + c: &mut CudaSlice, +) -> std::result::Result<(), cudarc::cublas::result::CublasError> { + use cudarc::cublas::sys; + use cudarc::driver::DevicePtrMut; + + let alpha = cfg.gemm.alpha; + let beta = cfg.gemm.beta; + let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); + let beta_f32: f32 = cfg.gemm.beta.to_f32(); + let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() { + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_16F, + (&alpha) as *const f16 as *const _, + (&beta) as *const f16 as *const _, + ) + } else { + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, + ) + }; + + cudarc::cublas::result::gemm_strided_batched_ex( + *cublas.handle(), + cfg.gemm.transa, + cfg.gemm.transb, + cfg.gemm.m, + cfg.gemm.n, + cfg.gemm.k, + alpha, + *a.device_ptr() as *const _, + sys::cudaDataType_t::CUDA_R_16F, + cfg.gemm.lda, + cfg.stride_a, + *b.device_ptr() as *const _, + sys::cudaDataType_t::CUDA_R_16F, + cfg.gemm.ldb, + cfg.stride_b, + beta, + *c.device_ptr_mut() as *mut _, + sys::cudaDataType_t::CUDA_R_16F, + cfg.gemm.ldc, + cfg.stride_c, + cfg.batch_size, + compute_type, + sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, + ) +} + +unsafe fn gemm_strided_batched_bf16( + cublas: &cudarc::cublas::CudaBlas, + cfg: StridedBatchedConfig, + a: &cudarc::driver::CudaView, + b: &cudarc::driver::CudaView, + c: &mut CudaSlice, +) -> std::result::Result<(), cudarc::cublas::result::CublasError> { + use cudarc::cublas::sys; + use cudarc::driver::DevicePtrMut; + + let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); + let beta_f32: f32 = cfg.gemm.beta.to_f32(); + // The type for alpha and beta depends on the computeType. 
+ // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex + let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() { + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, + ) + } else { + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, + ) + }; + + cudarc::cublas::result::gemm_strided_batched_ex( + *cublas.handle(), + cfg.gemm.transa, + cfg.gemm.transb, + cfg.gemm.m, + cfg.gemm.n, + cfg.gemm.k, + alpha, + *a.device_ptr() as *const _, + sys::cudaDataType_t::CUDA_R_16BF, + cfg.gemm.lda, + cfg.stride_a, + *b.device_ptr() as *const _, + sys::cudaDataType_t::CUDA_R_16BF, + cfg.gemm.ldb, + cfg.stride_b, + beta, + *c.device_ptr_mut() as *mut _, + sys::cudaDataType_t::CUDA_R_16BF, + cfg.gemm.ldc, + cfg.stride_c, + cfg.batch_size, + compute_type, + sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, + ) +} + +pub struct KVConcat { + pub concat_dim: usize, +} +impl crate::CustomOp2 for KVConcat { + fn name(&self) -> &'static str { + "kvconcat" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + crate::bail!("no cpu support for kvconcat") + } + + fn cuda_fwd( + &self, + ltensor: &CudaStorage, + ltensor_l: &Layout, + rtensor: &CudaStorage, + rtensor_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len + let dev = <ensor.device; + let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count(); + let dims_l = ltensor_l.shape().dims(); + let dims_r = rtensor_l.shape().dims(); + let dim_size = dims_l.len(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + + let chunk_l = if dim_size > 3 { + dims_l[0] * dims_l[1] + } else { + dims_l[0] + }; + let chunk_r = if dim_size > 3 { + dims_r[0] * dims_r[1] + } else { + dims_r[0] + }; + let lstride = if dim_size > 3 { + dims_l[2] * dims_l[3] + } else { + dims_l[1] * dims_l[2] + }; + let rstride = if dim_size > 3 { + dims_r[2] * dims_r[3] + } else { + dims_r[1] * dims_r[2] + }; + + let slice = match (<ensor.slice, &rtensor.slice) { + (CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? 
}; + let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(out) + } + (CudaStorageSlice::F8E4M3(left_), CudaStorageSlice::F8E4M3(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f8_e4m3", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(out) + } + (CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?, + }; + + let mut lshape: Vec = ltensor_l.shape().dims().to_vec(); + if self.concat_dim == 0 { + lshape[0] += rtensor_l.shape().dims()[0]; + } else { + if dim_size > 3 { + lshape[2] += rtensor_l.shape().dims()[2]; + } else { + lshape[1] += rtensor_l.shape().dims()[1]; + } + } + + let device = dev.clone(); + Ok(( + CudaStorage { + slice: slice, + device, + }, + lshape.into(), + )) + } +} diff --git a/mcandle/mcandle-core/src/cuda_backend/utils.rs b/mcandle/mcandle-core/src/cuda_backend/utils.rs new file mode 100644 index 000000000..9009c3930 --- /dev/null +++ b/mcandle/mcandle-core/src/cuda_backend/utils.rs @@ -0,0 +1,184 @@ +/// Helper functions to plug cuda kernels in mcandle. 
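As a hedged sketch of how these helpers are meant to be used: a unary op implements the Map1 trait defined below with a single dtype-generic kernel launch, and Map1::map then handles dispatch over CudaStorageSlice, mirroring the Affine/Powf/Elu wrappers used in mod.rs above. The Scale struct and the "uscale" kernel name are hypothetical; SlicePtrOrNull, kernel_name and kernels come from that mod.rs, and LaunchConfig/LaunchAsync from cudarc are assumed to be in scope as they are there.

    struct Scale(f64);
    impl Map1 for Scale {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &Layout,
        ) -> Result<CudaSlice<T>> {
            let shape = layout.shape();
            let el = shape.elem_count();
            let cfg = LaunchConfig::for_num_elems(el as u32);
            let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
            let src = &src.slice(layout.start_offset()..);
            let func = dev.get_or_load_func(&kernel_name::<T>("uscale"), kernels::UNARY)?;
            // SAFETY: the kernel fully initializes `out` before it is read.
            let out = unsafe { dev.alloc::<T>(el) }.w()?;
            let params = (el, shape.dims().len(), &ds, src, &out, T::from_f64(self.0));
            // SAFETY: ffi
            unsafe { func.launch(cfg, params) }.w()?;
            Ok(out)
        }
    }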
+use crate::{Layout, Result, Shape, WithDType}; +pub use cudarc; +use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits}; + +use super::{CudaDevice, CudaError, WrapErr}; + +pub type S = super::CudaStorageSlice; + +pub trait Map1 { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result>; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U8(s) => S::U8(self.f(s, d, l)?), + S::U32(s) => S::U32(self.f(s, d, l)?), + S::I16(s) => S::I16(self.f(s, d, l)?), + S::I32(s) => S::I32(self.f(s, d, l)?), + S::I64(s) => S::I64(self.f(s, d, l)?), + S::BF16(s) => S::BF16(self.f(s, d, l)?), + S::F16(s) => S::F16(self.f(s, d, l)?), + S::F32(s) => S::F32(self.f(s, d, l)?), + S::F64(s) => S::F64(self.f(s, d, l)?), + S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), + }; + Ok(out) + } +} + +pub trait Map2 { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result>; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), + (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), + (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), + (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), + (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), + (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, + }; + Ok(out) + } +} + +pub trait Map3 { + #[allow(clippy::too_many_arguments)] + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + src3: &CudaSlice, + layout3: &Layout, + dev: &CudaDevice, + ) -> Result>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + s1: &S, + l1: &Layout, + s2: &S, + l2: &Layout, + s3: &S, + l3: &Layout, + d: &CudaDevice, + ) -> Result { + let out = match (s1, s2, s3) { + (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { + S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) 
+ } + _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, + }; + Ok(out) + } +} + +pub trait Map2InPlace { + fn f( + &self, + dst: &mut CudaSlice, + dst_shape: &Shape, + src: &CudaSlice, + src_l: &Layout, + dev: &CudaDevice, + ) -> Result<()>; + + fn map( + &self, + dst: &mut S, + dst_s: &Shape, + src: &S, + src_l: &Layout, + d: &CudaDevice, + ) -> Result<()> { + match (dst, src) { + (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d), + (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d), + (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_s, src, src_l, d), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, + } + } +} + +pub trait Map1Any { + fn f) -> S>( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + wrap: W, + ) -> Result; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U8(s) => self.f(s, d, l, S::U8)?, + S::U32(s) => self.f(s, d, l, S::U32)?, + S::I16(s) => self.f(s, d, l, S::I16)?, + S::I32(s) => self.f(s, d, l, S::I32)?, + S::I64(s) => self.f(s, d, l, S::I64)?, + S::BF16(s) => self.f(s, d, l, S::BF16)?, + S::F16(s) => self.f(s, d, l, S::F16)?, + S::F32(s) => self.f(s, d, l, S::F32)?, + S::F64(s) => self.f(s, d, l, S::F64)?, + S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, + }; + Ok(out) + } +} + +pub trait Map2Any { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, + _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, + }; + Ok(out) + } +} diff --git a/mcandle/mcandle-core/src/custom_op.rs b/mcandle/mcandle-core/src/custom_op.rs new file mode 100644 index 000000000..54e54bb39 --- /dev/null +++ b/mcandle/mcandle-core/src/custom_op.rs @@ -0,0 +1,484 @@ +use crate::op::{BackpropOp, Op}; +use crate::tensor::from_storage; +use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor}; +use std::sync::Arc; + +/// Unary ops that can be defined in user-land. +pub trait CustomOp1 { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _storage: &MetalStorage, + _layout: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + + /// This function takes as argument the argument `arg` used in the forward pass, the result + /// produced by the forward operation `res` and the gradient of the result `grad_res`. + /// The function should return the gradient of the argument. + fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result> { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +pub trait CustomOp2 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd( + &self, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + ) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + + fn bwd( + &self, + _arg1: &Tensor, + _arg2: &Tensor, + _res: &Tensor, + _grad_res: &Tensor, + ) -> Result<(Option, Option)> { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +pub trait CustomOp3 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd( + &self, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + ) -> Result<(CudaStorage, Shape)> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn metal_fwd( + &self, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<(MetalStorage, Shape)> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } + + fn bwd( + &self, + _arg1: &Tensor, + _arg2: &Tensor, + _arg3: &Tensor, + _res: &Tensor, + _grad_res: &Tensor, + ) -> Result<(Option, Option, Option)> { + Err(crate::Error::BackwardNotSupported { op: self.name() }) + } +} + +impl Tensor { + /// Applies a unary custom op without backward support + pub fn apply_op1_no_bwd(&self, c: &C) -> Result { + let (storage, shape) = self.storage().apply_op1(self.layout(), c)?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a binary custom op without backward support + pub fn apply_op2_no_bwd(&self, rhs: &Self, c: &C) -> Result { + let (storage, shape) = + self.storage() + .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a ternary custom op without backward support + pub fn apply_op3_no_bwd(&self, t2: &Self, t3: &Self, c: &C) -> Result { + let (storage, shape) = self.storage().apply_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c, + )?; + Ok(from_storage(storage, shape, BackpropOp::none(), false)) + } + + /// Applies a unary custom op. + pub fn apply_op1_arc(&self, c: Arc>) -> Result { + let (storage, shape) = self + .storage() + .apply_op1(self.layout(), c.as_ref().as_ref())?; + let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone())); + Ok(from_storage(storage, shape, op, false)) + } + + pub fn apply_op1(&self, c: C) -> Result { + self.apply_op1_arc(Arc::new(Box::new(c))) + } + + /// Applies a binary custom op. + pub fn apply_op2_arc( + &self, + rhs: &Self, + c: Arc>, + ) -> Result { + let (storage, shape) = self.storage().apply_op2( + self.layout(), + &rhs.storage(), + rhs.layout(), + c.as_ref().as_ref(), + )?; + let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone())); + Ok(from_storage(storage, shape, op, false)) + } + + pub fn apply_op2(&self, r: &Self, c: C) -> Result { + self.apply_op2_arc(r, Arc::new(Box::new(c))) + } + + /// Applies a ternary custom op. + pub fn apply_op3_arc( + &self, + t2: &Self, + t3: &Self, + c: Arc>, + ) -> Result { + let (storage, shape) = self.storage().apply_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c.as_ref().as_ref(), + )?; + let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| { + Op::CustomOp3(t1, t2, t3, c.clone()) + }); + Ok(from_storage(storage, shape, op, false)) + } + + pub fn apply_op3( + &self, + t2: &Self, + t3: &Self, + c: C, + ) -> Result { + self.apply_op3_arc(t2, t3, Arc::new(Box::new(c))) + } +} + +// In place ops. + +/// Unary ops that can be defined in user-land. +/// These ops work in place and as such back-prop is unsupported. +pub trait InplaceOp1 { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
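Putting CustomOp1 and Tensor::apply_op1 (defined earlier in this file) together: a user-land op only has to provide a name and a cpu_fwd, with cuda_fwd/metal_fwd falling back to the default "no implementation" errors. A minimal sketch, assuming CpuStorage, CustomOp1, Layout, Shape, Tensor and the bail! macro are re-exported at the crate root as they are used elsewhere in this diff; OnesLike is a made-up example op.

    use mcandle_core::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

    struct OnesLike;

    impl CustomOp1 for OnesLike {
        fn name(&self) -> &'static str {
            "ones-like"
        }

        // CPU path only; the GPU defaults return "no cuda/metal implementation" errors.
        fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
            let el = layout.shape().elem_count();
            let out = match storage {
                CpuStorage::F32(_) => CpuStorage::F32(vec![1f32; el]),
                CpuStorage::F64(_) => CpuStorage::F64(vec![1f64; el]),
                _ => mcandle_core::bail!("ones-like sketch only handles f32/f64"),
            };
            Ok((out, layout.shape().clone()))
        }
    }

    fn run(t: &Tensor) -> Result<Tensor> {
        // No backward pass is defined here, so the *_no_bwd variant is the right entry point.
        t.apply_op1_no_bwd(&OnesLike)
    }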
+ fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } +} + +pub trait InplaceOp2 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout) + -> Result<()>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &mut MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<()> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } +} + +pub trait InplaceOp3 { + fn name(&self) -> &'static str; + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + s1: &mut CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<()>; + + /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cuda_fwd( + &self, + _: &mut CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + _: &CudaStorage, + _: &Layout, + ) -> Result<()> { + Err(crate::Error::Cuda( + format!("no cuda implementation for {}", self.name()).into(), + )) + } + + /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn metal_fwd( + &self, + _: &mut MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + _: &MetalStorage, + _: &Layout, + ) -> Result<()> { + Err(crate::Error::Metal( + format!("no metal implementation for {}", self.name()).into(), + )) + } +} + +impl Tensor { + /// Applies a unary custom op in place. + pub fn inplace_op1(&self, c: &C) -> Result<()> { + self.storage_mut().inplace_op1(self.layout(), c) + } + + /// Applies a unary custom op in place (for the first tensor). + pub fn inplace_op2(&self, rhs: &Self, c: &C) -> Result<()> { + self.storage_mut() + .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c) + } + + /// Applies a ternary custom op in place (for the first tensor). 
+ pub fn inplace_op3(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> { + self.storage_mut().inplace_op3( + self.layout(), + &t2.storage(), + t2.layout(), + &t3.storage(), + t3.layout(), + c, + ) + } +} + +pub struct UgIOp1 { + name: &'static str, + #[cfg(feature = "cuda")] + func: cudarc::driver::CudaFunction, + #[cfg(feature = "metal")] + func: metal::ComputePipelineState, +} + +impl UgIOp1 { + #[allow(unused)] + pub fn new( + name: &'static str, + kernel: ug::lang::ssa::Kernel, + device: &crate::Device, + ) -> Result { + #[cfg(feature = "cuda")] + { + let device = device.as_cuda_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(feature = "metal")] + { + let device = device.as_metal_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] + { + Ok(Self { name }) + } + } +} + +impl InplaceOp1 for UgIOp1 { + fn name(&self) -> &'static str { + self.name + } + + fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { + crate::bail!("ug ops are only supported on metal/cuda at the moment") + } + + #[cfg(feature = "metal")] + fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> { + use crate::backend::BackendStorage; + use mcandle_metal_kernels::utils::EncoderProvider; + + let elem_count = layout.shape().elem_count(); + if sto.dtype() != crate::DType::F32 { + // TODO: support more dtypes. + crate::bail!("input is not a f32 tensor") + } + let device = sto.device(); + println!("here"); + let command_buffer = device.command_buffer()?; + let command_buffer = &command_buffer; + let encoder = command_buffer.encoder(); + let encoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&self.func); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let grid_dims = metal::MTLSize { + width: g as u64, + height: 1, + depth: 1, + }; + let group_dims = mcandle_metal_kernels::utils::get_block_dims(b as u64, 1, 1); + mcandle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize)); + + encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + + Ok(()) + } + + #[cfg(feature = "cuda")] + fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> { + use crate::cuda_backend::WrapErr; + use cudarc::driver::LaunchAsync; + + let elem_count = layout.shape().elem_count(); + // TODO: support more dtypes. + let sto = sto.as_cuda_slice::()?; + let sto = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => sto.slice(o1..o2), + }; + let params = (&sto,); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (g as u32, 1, 1), + block_dim: (b as u32, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { self.func.clone().launch(cfg, params) }.w()?; + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/device.rs b/mcandle/mcandle-core/src/device.rs new file mode 100644 index 000000000..c0da162be --- /dev/null +++ b/mcandle/mcandle-core/src/device.rs @@ -0,0 +1,407 @@ +use crate::backend::BackendDevice; +use crate::cpu_backend::CpuDevice; +use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; + +/// A `DeviceLocation` represents a physical device whereas multiple `Device` +/// can live on the same location (typically for cuda devices). 
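+/// For instance, `Device::new_cuda(0)?` and `Device::new_cuda_with_stream(0)?` are distinct
+/// `Device` values that both report `DeviceLocation::Cuda { gpu_id: 0 }` as their location.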
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum DeviceLocation { + Cpu, + Cuda { gpu_id: usize }, + Metal { gpu_id: usize }, +} + +#[derive(Debug, Clone)] +pub enum Device { + Cpu, + Cuda(crate::CudaDevice), + Metal(crate::MetalDevice), +} + +pub trait NdArray { + fn shape(&self) -> Result; + + fn to_cpu_storage(&self) -> CpuStorage; +} + +impl NdArray for S { + fn shape(&self) -> Result { + Ok(Shape::from(())) + } + + fn to_cpu_storage(&self) -> CpuStorage { + S::to_cpu_storage(&[*self]) + } +} + +impl NdArray for &[S; N] { + fn shape(&self) -> Result { + Ok(Shape::from(self.len())) + } + + fn to_cpu_storage(&self) -> CpuStorage { + S::to_cpu_storage(self.as_slice()) + } +} + +impl NdArray for &[S] { + fn shape(&self) -> Result { + Ok(Shape::from(self.len())) + } + + fn to_cpu_storage(&self) -> CpuStorage { + S::to_cpu_storage(self) + } +} + +impl NdArray for &[[S; N]; M] { + fn shape(&self) -> Result { + Ok(Shape::from((M, N))) + } + + fn to_cpu_storage(&self) -> CpuStorage { + S::to_cpu_storage_owned(self.concat()) + } +} + +impl NdArray + for &[[[S; N3]; N2]; N1] +{ + fn shape(&self) -> Result { + Ok(Shape::from((N1, N2, N3))) + } + + fn to_cpu_storage(&self) -> CpuStorage { + let mut vec = Vec::with_capacity(N1 * N2 * N3); + for i1 in 0..N1 { + for i2 in 0..N2 { + vec.extend(self[i1][i2]) + } + } + S::to_cpu_storage_owned(vec) + } +} + +impl NdArray + for &[[[[S; N4]; N3]; N2]; N1] +{ + fn shape(&self) -> Result { + Ok(Shape::from((N1, N2, N3, N4))) + } + + fn to_cpu_storage(&self) -> CpuStorage { + let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4); + for i1 in 0..N1 { + for i2 in 0..N2 { + for i3 in 0..N3 { + vec.extend(self[i1][i2][i3]) + } + } + } + S::to_cpu_storage_owned(vec) + } +} + +impl NdArray for Vec { + fn shape(&self) -> Result { + if self.is_empty() { + crate::bail!("empty array") + } + let shape0 = self[0].shape()?; + let n = self.len(); + for v in self.iter() { + let shape = v.shape()?; + if shape != shape0 { + crate::bail!("two elements have different shapes {shape:?} {shape0:?}") + } + } + Ok(Shape::from([[n].as_slice(), shape0.dims()].concat())) + } + + fn to_cpu_storage(&self) -> CpuStorage { + // This allocates intermediary memory and shouldn't be necessary. + let storages = self.iter().map(|v| v.to_cpu_storage()).collect::>(); + CpuStorage::concat(storages.as_slice()).unwrap() + } +} + +impl Device { + pub fn new_cuda(ordinal: usize) -> Result { + Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) + } + + pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> { + match self { + Self::Cuda(d) => Ok(d), + Self::Cpu => crate::bail!("expected a cuda device, got cpu"), + Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"), + } + } + + pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> { + match self { + Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"), + Self::Cpu => crate::bail!("expected a metal device, got cpu"), + Self::Metal(d) => Ok(d), + } + } + + pub fn new_cuda_with_stream(ordinal: usize) -> Result { + Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) + } + + pub fn new_metal(ordinal: usize) -> Result { + Ok(Self::Metal(crate::MetalDevice::new(ordinal)?)) + } + + pub fn set_seed(&self, seed: u64) -> Result<()> { + match self { + Self::Cpu => CpuDevice.set_seed(seed), + Self::Cuda(c) => c.set_seed(seed), + Self::Metal(m) => m.set_seed(seed), + } + } + + /// Get the current seed for the device RNG. 
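+    ///
+    /// A small sketch, assuming the backend reports back the last value passed to
+    /// `set_seed`:
+    ///
+    /// ```ignore
+    /// use mcandle_core::Device;
+    /// let device = Device::Cpu;
+    /// device.set_seed(42)?;
+    /// assert_eq!(device.get_current_seed()?, 42);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```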
+ pub fn get_current_seed(&self) -> Result { + match self { + Self::Cpu => CpuDevice.get_current_seed(), + Self::Cuda(c) => c.get_current_seed(), + Self::Metal(m) => m.get_current_seed(), + } + } + + pub fn same_device(&self, rhs: &Self) -> bool { + match (self, rhs) { + (Self::Cpu, Self::Cpu) => true, + (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), + (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs), + _ => false, + } + } + + pub fn location(&self) -> DeviceLocation { + match self { + Self::Cpu => DeviceLocation::Cpu, + Self::Cuda(device) => device.location(), + Device::Metal(device) => device.location(), + } + } + + pub fn is_cpu(&self) -> bool { + matches!(self, Self::Cpu) + } + + pub fn is_cuda(&self) -> bool { + matches!(self, Self::Cuda(_)) + } + + pub fn is_metal(&self) -> bool { + matches!(self, Self::Metal(_)) + } + + pub fn supports_bf16(&self) -> bool { + match self { + Self::Cuda(_) | Self::Metal(_) => true, + Self::Cpu => false, + } + } + + /// Return `BF16` for devices that support it, otherwise default to `F32`. + pub fn bf16_default_to_f32(&self) -> DType { + if self.supports_bf16() { + DType::BF16 + } else { + DType::F32 + } + } + + pub fn cuda_if_available(ordinal: usize) -> Result { + if crate::utils::cuda_is_available() { + Self::new_cuda(ordinal) + } else { + Ok(Self::Cpu) + } + } + + pub(crate) fn rand_uniform_f64( + &self, + lo: f64, + up: f64, + shape: &Shape, + dtype: DType, + ) -> Result { + match self { + Device::Cpu => { + let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + // TODO: Remove the special case if we start supporting generating f16/bf16 directly. + if dtype == DType::F16 || dtype == DType::BF16 { + let storage = device.rand_uniform(shape, DType::F32, lo, up)?; + Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype) + } else { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Cuda(storage)) + } + } + Device::Metal(device) => { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) fn rand_uniform( + &self, + lo: T, + up: T, + shape: &Shape, + ) -> Result { + self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE) + } + + pub(crate) fn rand_normal_f64( + &self, + mean: f64, + std: f64, + shape: &Shape, + dtype: DType, + ) -> Result { + match self { + Device::Cpu => { + let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + // TODO: Remove the special case if we start supporting generating f16/bf16 directly. 
+ if dtype == DType::F16 || dtype == DType::BF16 { + let storage = device.rand_normal(shape, DType::F32, mean, std)?; + Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype) + } else { + let storage = device.rand_normal(shape, dtype, mean, std)?; + Ok(Storage::Cuda(storage)) + } + } + Device::Metal(device) => { + let storage = device.rand_normal(shape, dtype, mean, std)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) fn rand_normal( + &self, + mean: T, + std: T, + shape: &Shape, + ) -> Result { + self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE) + } + + pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result { + match self { + Device::Cpu => { + let storage = CpuDevice.ones_impl(shape, dtype)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + let storage = device.ones_impl(shape, dtype)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = device.ones_impl(shape, dtype)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result { + match self { + Device::Cpu => { + let storage = CpuDevice.zeros_impl(shape, dtype)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + let storage = device.zeros_impl(shape, dtype)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = device.zeros_impl(shape, dtype)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + match self { + Device::Cpu => { + let storage = CpuDevice.alloc_uninit(shape, dtype)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + let storage = device.alloc_uninit(shape, dtype)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = device.alloc_uninit(shape, dtype)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) fn storage_from_slice(&self, data: &[D]) -> Result { + match self { + Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())), + Device::Cuda(device) => { + let storage = device.storage_from_slice(data)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = device.storage_from_slice(data)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) fn storage(&self, array: A) -> Result { + match self { + Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), + Device::Cuda(device) => { + let storage = array.to_cpu_storage(); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = array.to_cpu_storage(); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub(crate) fn storage_owned(&self, data: Vec) -> Result { + match self { + Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))), + Device::Cuda(device) => { + let storage = S::to_cpu_storage_owned(data); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = S::to_cpu_storage_owned(data); + let storage = device.storage_from_cpu_storage(&storage)?; + Ok(Storage::Metal(storage)) + } + } + } + + pub fn synchronize(&self) -> Result<()> { + match self { + Self::Cpu => Ok(()), + Self::Cuda(d) => d.synchronize(), + Self::Metal(d) => d.synchronize(), + } + } +} diff --git a/mcandle/mcandle-core/src/display.rs b/mcandle/mcandle-core/src/display.rs new file mode 100644 index 000000000..8b8f49503 --- /dev/null +++ 
b/mcandle/mcandle-core/src/display.rs @@ -0,0 +1,543 @@ +/// Pretty printing of tensors +/// This implementation should be in line with the PyTorch version. +/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py +use crate::{DType, Result, Tensor, WithDType}; +use float8::F8E4M3; +use half::{bf16, f16}; + +impl Tensor { + fn fmt_dt( + &self, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let device_str = match self.device().location() { + crate::DeviceLocation::Cpu => "".to_owned(), + crate::DeviceLocation::Cuda { gpu_id } => { + format!(", cuda:{}", gpu_id) + } + crate::DeviceLocation::Metal { gpu_id } => { + format!(", metal:{}", gpu_id) + } + }; + + write!(f, "Tensor[")?; + match self.dims() { + [] => { + if let Ok(v) = self.to_scalar::() { + write!(f, "{v}")? + } + } + [s] if *s < 10 => { + if let Ok(vs) = self.to_vec1::() { + for (i, v) in vs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{v}")?; + } + } + } + dims => { + write!(f, "dims ")?; + for (i, d) in dims.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{d}")?; + } + } + } + write!(f, "; {}{}]", self.dtype().as_str(), device_str) + } +} + +impl std::fmt::Debug for Tensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self.dtype() { + DType::U8 => self.fmt_dt::(f), + DType::U32 => self.fmt_dt::(f), + DType::I16 => self.fmt_dt::(f), + DType::I32 => self.fmt_dt::(f), + DType::I64 => self.fmt_dt::(f), + DType::BF16 => self.fmt_dt::(f), + DType::F16 => self.fmt_dt::(f), + DType::F32 => self.fmt_dt::(f), + DType::F64 => self.fmt_dt::(f), + DType::F8E4M3 => self.fmt_dt::(f), + } + } +} + +/// Options for Tensor pretty printing +#[derive(Debug, Clone)] +pub struct PrinterOptions { + pub precision: usize, + pub threshold: usize, + pub edge_items: usize, + pub line_width: usize, + pub sci_mode: Option, +} + +static PRINT_OPTS: std::sync::Mutex = + std::sync::Mutex::new(PrinterOptions::const_default()); + +impl PrinterOptions { + // We cannot use the default trait as it's not const. 
+ const fn const_default() -> Self { + Self { + precision: 4, + threshold: 1000, + edge_items: 3, + line_width: 80, + sci_mode: None, + } + } +} + +pub fn print_options() -> &'static std::sync::Mutex { + &PRINT_OPTS +} + +pub fn set_print_options(options: PrinterOptions) { + *PRINT_OPTS.lock().unwrap() = options +} + +pub fn set_print_options_default() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions::const_default() +} + +pub fn set_print_options_short() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions { + precision: 2, + threshold: 1000, + edge_items: 2, + line_width: 80, + sci_mode: None, + } +} + +pub fn set_print_options_full() { + *PRINT_OPTS.lock().unwrap() = PrinterOptions { + precision: 4, + threshold: usize::MAX, + edge_items: 3, + line_width: 80, + sci_mode: None, + } +} + +pub fn set_line_width(line_width: usize) { + PRINT_OPTS.lock().unwrap().line_width = line_width +} + +pub fn set_precision(precision: usize) { + PRINT_OPTS.lock().unwrap().precision = precision +} + +pub fn set_edge_items(edge_items: usize) { + PRINT_OPTS.lock().unwrap().edge_items = edge_items +} + +pub fn set_threshold(threshold: usize) { + PRINT_OPTS.lock().unwrap().threshold = threshold +} + +pub fn set_sci_mode(sci_mode: Option) { + PRINT_OPTS.lock().unwrap().sci_mode = sci_mode +} + +struct FmtSize { + current_size: usize, +} + +impl FmtSize { + fn new() -> Self { + Self { current_size: 0 } + } + + fn final_size(self) -> usize { + self.current_size + } +} + +impl std::fmt::Write for FmtSize { + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.current_size += s.len(); + Ok(()) + } +} + +trait TensorFormatter { + type Elem: WithDType; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result; + + fn max_width(&self, to_display: &Tensor) -> usize { + let mut max_width = 1; + if let Ok(vs) = to_display.flatten_all().and_then(|t| t.to_vec1()) { + for &v in vs.iter() { + let mut fmt_size = FmtSize::new(); + let _res = self.fmt(v, 1, &mut fmt_size); + max_width = usize::max(max_width, fmt_size.final_size()) + } + } + max_width + } + + fn write_newline_indent(i: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f)?; + for _ in 0..i { + write!(f, " ")? + } + Ok(()) + } + + fn fmt_tensor( + &self, + t: &Tensor, + indent: usize, + max_w: usize, + summarize: bool, + po: &PrinterOptions, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let dims = t.dims(); + let edge_items = po.edge_items; + write!(f, "[")?; + match dims { + [] => { + if let Ok(v) = t.to_scalar::() { + self.fmt(v, max_w, f)? + } + } + [v] if summarize && *v > 2 * edge_items => { + if let Ok(vs) = t + .narrow(0, 0, edge_items) + .and_then(|t| t.to_vec1::()) + { + for v in vs.into_iter() { + self.fmt(v, max_w, f)?; + write!(f, ", ")?; + } + } + write!(f, "...")?; + if let Ok(vs) = t + .narrow(0, v - edge_items, edge_items) + .and_then(|t| t.to_vec1::()) + { + for v in vs.into_iter() { + write!(f, ", ")?; + self.fmt(v, max_w, f)?; + } + } + } + [_] => { + let elements_per_line = usize::max(1, po.line_width / (max_w + 2)); + if let Ok(vs) = t.to_vec1::() { + for (i, v) in vs.into_iter().enumerate() { + if i > 0 { + if i % elements_per_line == 0 { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } else { + write!(f, ", ")?; + } + } + self.fmt(v, max_w, f)? 
+ } + } + } + _ => { + if summarize && dims[0] > 2 * edge_items { + for i in 0..edge_items { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + write!(f, "...")?; + Self::write_newline_indent(indent, f)?; + for i in dims[0] - edge_items..dims[0] { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + if i + 1 != dims[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } else { + for i in 0..dims[0] { + match t.get(i) { + Ok(t) => self.fmt_tensor(&t, indent + 1, max_w, summarize, po, f)?, + Err(e) => write!(f, "{e:?}")?, + } + if i + 1 != dims[0] { + write!(f, ",")?; + Self::write_newline_indent(indent, f)? + } + } + } + } + } + write!(f, "]")?; + Ok(()) + } +} + +struct FloatFormatter { + int_mode: bool, + sci_mode: bool, + precision: usize, + _phantom: std::marker::PhantomData, +} + +impl FloatFormatter +where + S: WithDType + num_traits::Float + std::fmt::Display, +{ + fn new(t: &Tensor, po: &PrinterOptions) -> Result { + let mut int_mode = true; + let mut sci_mode = false; + + // Rather than containing all values, this should only include + // values that end up being displayed according to [threshold]. + let values = t + .flatten_all()? + .to_vec1()? + .into_iter() + .filter(|v: &S| v.is_finite() && !v.is_zero()) + .collect::>(); + if !values.is_empty() { + let mut nonzero_finite_min = S::max_value(); + let mut nonzero_finite_max = S::min_value(); + for &v in values.iter() { + let v = v.abs(); + if v < nonzero_finite_min { + nonzero_finite_min = v + } + if v > nonzero_finite_max { + nonzero_finite_max = v + } + } + + for &value in values.iter() { + if value.ceil() != value { + int_mode = false; + break; + } + } + if let Some(v1) = S::from(1000.) 
{ + if let Some(v2) = S::from(1e8) { + if let Some(v3) = S::from(1e-4) { + sci_mode = nonzero_finite_max / nonzero_finite_min > v1 + || nonzero_finite_max > v2 + || nonzero_finite_min < v3 + } + } + } + } + + match po.sci_mode { + None => {} + Some(v) => sci_mode = v, + } + Ok(Self { + int_mode, + sci_mode, + precision: po.precision, + _phantom: std::marker::PhantomData, + }) + } +} + +impl TensorFormatter for FloatFormatter +where + S: WithDType + num_traits::Float + std::fmt::Display + std::fmt::LowerExp, +{ + type Elem = S; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + if self.sci_mode { + write!( + f, + "{v:width$.prec$e}", + v = v, + width = max_w, + prec = self.precision + ) + } else if self.int_mode { + if v.is_finite() { + write!(f, "{v:width$.0}.", v = v, width = max_w - 1) + } else { + write!(f, "{v:max_w$.0}") + } + } else { + write!( + f, + "{v:width$.prec$}", + v = v, + width = max_w, + prec = self.precision + ) + } + } +} + +struct IntFormatter { + _phantom: std::marker::PhantomData, +} + +impl IntFormatter { + fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } +} + +impl TensorFormatter for IntFormatter +where + S: WithDType + std::fmt::Display, +{ + type Elem = S; + + fn fmt(&self, v: Self::Elem, max_w: usize, f: &mut T) -> std::fmt::Result { + write!(f, "{v:max_w$}") + } +} + +fn get_summarized_data(t: &Tensor, edge_items: usize) -> Result { + let dims = t.dims(); + if dims.is_empty() { + Ok(t.clone()) + } else if dims.len() == 1 { + if dims[0] > 2 * edge_items { + Tensor::cat( + &[ + t.narrow(0, 0, edge_items)?, + t.narrow(0, dims[0] - edge_items, edge_items)?, + ], + 0, + ) + } else { + Ok(t.clone()) + } + } else if dims[0] > 2 * edge_items { + let mut vs: Vec<_> = (0..edge_items) + .map(|i| get_summarized_data(&t.get(i)?, edge_items)) + .collect::>>()?; + for i in (dims[0] - edge_items)..dims[0] { + vs.push(get_summarized_data(&t.get(i)?, edge_items)?) 
+ } + Tensor::cat(&vs, 0) + } else { + let vs: Vec<_> = (0..dims[0]) + .map(|i| get_summarized_data(&t.get(i)?, edge_items)) + .collect::>>()?; + Tensor::cat(&vs, 0) + } +} + +impl std::fmt::Display for Tensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let po = PRINT_OPTS.lock().unwrap(); + let summarize = self.elem_count() > po.threshold; + let to_display = if summarize { + match get_summarized_data(self, po.edge_items) { + Ok(v) => v, + Err(err) => return write!(f, "{err:?}"), + } + } else { + self.clone() + }; + match self.dtype() { + DType::U8 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::U32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::I16 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::I32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::I64 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + DType::BF16 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F16 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F64 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F32 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + DType::F8E4M3 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } + }; + + let device_str = match self.device().location() { + crate::DeviceLocation::Cpu => "".to_owned(), + crate::DeviceLocation::Cuda { gpu_id } => { + format!(", cuda:{}", gpu_id) + } + crate::DeviceLocation::Metal { gpu_id } => { + format!(", metal:{}", gpu_id) + } + }; + + write!( + f, + "Tensor[{:?}, {}{}]", + self.dims(), + self.dtype().as_str(), + device_str + ) + } +} diff --git a/mcandle/mcandle-core/src/dtype.rs b/mcandle/mcandle-core/src/dtype.rs new file mode 100644 index 000000000..f40ec3f7e --- /dev/null +++ b/mcandle/mcandle-core/src/dtype.rs @@ -0,0 +1,264 @@ +//! Types for elements that can be stored and manipulated using tensors. +#![allow(clippy::redundant_closure_call)] +use crate::backend::BackendStorage; +use crate::cpu::kernels::VecOps; +use crate::{CpuStorage, CpuStorageRef, Error, Result}; + +/// The different types of elements allowed in tensors. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum DType { + // Floating-point 8 bits integer (4-bit exponent, 3-bit mantissa). + F8E4M3, + // Unsigned 8 bits integer. + U8, + // Unsigned 32 bits integer. 
+ U32, + // Signed 16 bits integer. + I16, + // Signed 32 bits integer. + I32, + // Signed 64 bits integer. + I64, + // Brain floating-point using half precision (16 bits). + BF16, + // Floating-point using half precision (16 bits). + F16, + // Floating-point using single precision (32 bits). + F32, + // Floating-point using double precision (64 bits). + F64, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct DTypeParseError(String); + +impl std::fmt::Display for DTypeParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "cannot parse '{}' as a dtype", self.0) + } +} + +impl std::error::Error for DTypeParseError {} + +impl std::str::FromStr for DType { + type Err = DTypeParseError; + fn from_str(s: &str) -> std::result::Result { + match s { + "u8" => Ok(Self::U8), + "u32" => Ok(Self::U32), + "i16" => Ok(Self::I16), + "i32" => Ok(Self::I32), + "i64" => Ok(Self::I64), + "bf16" => Ok(Self::BF16), + "f16" => Ok(Self::F16), + "f32" => Ok(Self::F32), + "f64" => Ok(Self::F64), + "f8_e4m3" => Ok(Self::F8E4M3), + _ => Err(DTypeParseError(s.to_string())), + } + } +} + +impl DType { + /// String representation for dtypes. + pub fn as_str(&self) -> &'static str { + match self { + Self::U8 => "u8", + Self::U32 => "u32", + Self::I16 => "i16", + Self::I32 => "i32", + Self::I64 => "i64", + Self::BF16 => "bf16", + Self::F16 => "f16", + Self::F32 => "f32", + Self::F64 => "f64", + Self::F8E4M3 => "f8_e4m3", + } + } + + /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`. + pub fn size_in_bytes(&self) -> usize { + match self { + Self::U8 => 1, + Self::F8E4M3 => 1, + Self::U32 => 4, + Self::I16 => 2, + Self::I32 => 4, + Self::I64 => 8, + Self::BF16 => 2, + Self::F16 => 2, + Self::F32 => 4, + Self::F64 => 8, + } + } + + pub fn is_int(&self) -> bool { + match self { + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => false, + } + } + + pub fn is_float(&self) -> bool { + match self { + Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => true, + } + } +} + +pub trait WithDType: + Sized + + Copy + + num_traits::NumAssign + + std::cmp::PartialOrd + + std::fmt::Display + + 'static + + Send + + Sync + + std::any::Any + + crate::cpu::kernels::VecOps +{ + const DTYPE: DType; + + fn from_f64(v: f64) -> Self; + fn to_f64(self) -> f64; + fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>; + fn to_cpu_storage_owned(data: Vec) -> CpuStorage; + + fn to_cpu_storage(data: &[Self]) -> CpuStorage { + Self::to_cpu_storage_owned(data.to_vec()) + } + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; + fn cpu_storage_data(s: CpuStorage) -> Result>; +} + +macro_rules! 
with_dtype { + ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => { + impl WithDType for $ty { + const DTYPE: DType = DType::$dtype; + + fn from_f64(v: f64) -> Self { + $from_f64(v) + } + + fn to_f64(self) -> f64 { + $to_f64(self) + } + + fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> { + CpuStorageRef::$dtype(data) + } + + fn to_cpu_storage_owned(data: Vec) -> CpuStorage { + CpuStorage::$dtype(data) + } + + fn cpu_storage_data(s: CpuStorage) -> Result> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + msg: "unexpected dtype", + } + .bt()), + } + } + + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + msg: "unexpected dtype", + } + .bt()), + } + } + } + }; +} +use float8::F8E4M3; +use half::{bf16, f16}; + +with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); +with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64); +with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); +with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); +with_dtype!(f16, F16, f16::from_f64, f16::to_f64); +with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); +with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); +with_dtype!(f64, F64, |v: f64| v, |v: f64| v); +with_dtype!(F8E4M3, F8E4M3, |v: f64| F8E4M3::from_f64(v), |v: F8E4M3| v + .to_f64()); + +impl VecOps for F8E4M3 { + fn max(self, rhs: Self) -> Self { + F8E4M3::max(self, rhs) + } + fn min(self, rhs: Self) -> Self { + F8E4M3::min(self, rhs) + } +} + +pub trait IntDType: WithDType { + fn is_true(&self) -> bool; + fn as_usize(&self) -> usize; +} + +impl IntDType for i16 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for i32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for i64 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for u32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +impl IntDType for u8 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + +pub trait FloatDType: WithDType {} + +impl FloatDType for f16 {} +impl FloatDType for bf16 {} +impl FloatDType for f32 {} +impl FloatDType for f64 {} +impl FloatDType for F8E4M3 {} diff --git a/mcandle/mcandle-core/src/dummy_cuda_backend.rs b/mcandle/mcandle-core/src/dummy_cuda_backend.rs new file mode 100644 index 000000000..814519ba3 --- /dev/null +++ b/mcandle/mcandle-core/src/dummy_cuda_backend.rs @@ -0,0 +1,294 @@ +#![allow(dead_code)] +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; + +#[derive(Debug, Clone)] +pub struct CudaDevice; + +#[derive(Debug)] +pub struct CudaStorage; + +macro_rules! 
fail { + () => { + unimplemented!("cuda support has not been enabled, add `cuda` feature to enable.") + }; +} + +impl CudaDevice { + pub fn new_with_stream(_: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } +} + +impl crate::backend::BackendStorage for CudaStorage { + type Device = CudaDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn dtype(&self) -> DType { + fail!() + } + + fn device(&self) -> &Self::Device { + fail!() + } + + fn to_cpu_storage(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn to_dtype(&self, _: &Layout, _: DType) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn unary_impl(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn conv1d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv1D, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn conv_transpose1d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConvTranspose1D, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn conv2d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv2D, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn matmul_with_alpha_beta( + &self, + _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + fn copy2d( + &self, + _: &mut Self, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + ) -> Result<()> { + 
Err(Error::NotCompiledWithCudaSupport) + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } +} + +impl crate::backend::BackendDevice for CudaDevice { + type Storage = CudaStorage; + fn new(_: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn set_seed(&self, _: u64) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn location(&self) -> crate::DeviceLocation { + fail!() + } + + fn same_device(&self, _: &Self) -> bool { + fail!() + } + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn storage_from_slice(&self, _: &[T]) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + fn synchronize(&self) -> Result<()> { + Ok(()) + } +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with f16 GEMMs. +pub fn gemm_reduced_precision_f16() -> bool { + true +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with f16 GEMMs. +pub fn set_gemm_reduced_precision_f16(_: bool) {} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with bf16 GEMMs. +pub fn gemm_reduced_precision_bf16() -> bool { + true +} + +/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are +/// allowed with bf16 GEMMs. +pub fn set_gemm_reduced_precision_bf16(_: bool) {} + +/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are +/// allowed with f32 GEMMs. +pub fn gemm_reduced_precision_f32() -> bool { + true +} + +/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are +/// allowed with f32 GEMMs. 
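+///
+/// In this fallback (non-`cuda`) build the setter is a no-op and the getter always reports
+/// `true`; the flag only takes effect when the crate is built with the `cuda` feature. A
+/// sketch, assuming this module is exposed as `mcandle_core::cuda` like its CUDA counterpart:
+///
+/// ```ignore
+/// // No-op here; with the real CUDA backend this would disallow tf32 accumulation in f32 GEMMs.
+/// mcandle_core::cuda::set_gemm_reduced_precision_f32(false);
+/// let _allowed = mcandle_core::cuda::gemm_reduced_precision_f32();
+/// ```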
+pub fn set_gemm_reduced_precision_f32(_b: bool) {} diff --git a/mcandle/mcandle-core/src/dummy_metal_backend.rs b/mcandle/mcandle-core/src/dummy_metal_backend.rs new file mode 100644 index 000000000..2a3ea93c0 --- /dev/null +++ b/mcandle/mcandle-core/src/dummy_metal_backend.rs @@ -0,0 +1,270 @@ +#![allow(dead_code)] +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; + +#[derive(Debug, Clone)] +pub struct MetalDevice; + +#[derive(Debug)] +pub struct MetalStorage; + +#[derive(thiserror::Error, Debug)] +pub enum MetalError { + #[error("{0}")] + Message(String), +} + +impl From for MetalError { + fn from(e: String) -> Self { + MetalError::Message(e) + } +} + +macro_rules! fail { + () => { + unimplemented!("metal support has not been enabled, add `metal` feature to enable.") + }; +} + +impl crate::backend::BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn dtype(&self) -> DType { + fail!() + } + + fn device(&self) -> &Self::Device { + fail!() + } + + fn to_cpu_storage(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn to_dtype(&self, _: &Layout, _: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn unary_impl(&self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv1d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv1D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv_transpose1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose1D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv2d( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &crate::conv::ParamsConv2D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn matmul_with_alpha_beta( + 
&self, + _: &Self, + _: &mut Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + _: &Layout, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn matmul_with_alpha( + &self, + _: &Self, + _: Option, + _: (usize, usize, usize, usize), + _: &Layout, + _: &Layout, + ) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn copy2d( + &self, + _: &mut Self, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } +} + +impl crate::backend::BackendDevice for MetalDevice { + type Storage = MetalStorage; + fn new(_: usize) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn set_seed(&self, _: u64) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + + fn get_current_seed(&self) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn location(&self) -> crate::DeviceLocation { + fail!() + } + + fn same_device(&self, _: &Self) -> bool { + fail!() + } + + fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn storage_from_slice(&self, _: &[T]) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result { + Err(Error::NotCompiledWithMetalSupport) + } + + fn synchronize(&self) -> Result<()> { + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/error.rs b/mcandle/mcandle-core/src/error.rs new file mode 100644 index 000000000..4d2577f7d --- /dev/null +++ b/mcandle/mcandle-core/src/error.rs @@ -0,0 +1,394 @@ +use std::{ + convert::Infallible, + fmt::{Debug, Display}, +}; + +use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; + +#[derive(Debug, Clone)] +pub struct MatMulUnexpectedStriding { + pub lhs_l: Layout, + pub rhs_l: Layout, + pub bmnk: (usize, usize, usize, usize), + pub msg: &'static str, +} + +/// Main library error type. 
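+///
+/// Most fallible APIs in this crate return the crate-level `Result`; ad-hoc errors are
+/// typically created through the `bail!` macro or `Error::msg`. A small usage sketch:
+///
+/// ```rust
+/// use mcandle_core::{bail, Result};
+///
+/// fn checked_dim(dim: usize, rank: usize) -> Result<usize> {
+///     if dim >= rank {
+///         bail!("dimension index {dim} is out of range for rank {rank}")
+///     }
+///     Ok(dim)
+/// }
+///
+/// assert!(checked_dim(3, 2).is_err());
+/// ```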
+#[derive(thiserror::Error, Debug)] +pub enum Error { + // === DType Errors === + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, + + #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + DTypeMismatchBinaryOp { + lhs: DType, + rhs: DType, + op: &'static str, + }, + + #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}, c: {rhs:?}")] + DTypeMismatchBinaryOp3 { + lhs: DType, + rhs: DType, + c: DType, + op: &'static str, + }, + + #[error("unsupported dtype {0:?} for op {1}")] + UnsupportedDTypeForOp(DType, &'static str), + + // === Dimension Index Errors === + #[error("{op}: dimension index {dim} out of range for shape {shape:?}")] + DimOutOfRange { + shape: Shape, + dim: i32, + op: &'static str, + }, + + #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")] + DuplicateDimIndex { + shape: Shape, + dims: Vec, + op: &'static str, + }, + + // === Shape Errors === + #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] + UnexpectedNumberOfDims { + expected: usize, + got: usize, + shape: Shape, + }, + + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedShape { + msg: String, + expected: Shape, + got: Shape, + }, + + #[error( + "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}" + )] + ShapeMismatch { buffer_size: usize, shape: Shape }, + + #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + ShapeMismatchBinaryOp { + lhs: Shape, + rhs: Shape, + op: &'static str, + }, + + #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")] + ShapeMismatchCat { + dim: usize, + first_shape: Shape, + n: usize, + nth_shape: Shape, + }, + + #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")] + ShapeMismatchSplit { + shape: Shape, + dim: usize, + n_parts: usize, + }, + + #[error("{op} can only be performed on a single dimension")] + OnlySingleDimension { op: &'static str, dims: Vec }, + + #[error("empty tensor for {op}")] + EmptyTensor { op: &'static str }, + + // === Device Errors === + #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] + DeviceMismatchBinaryOp { + lhs: DeviceLocation, + rhs: DeviceLocation, + op: &'static str, + }, + + #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}, c: {c:?}")] + DeviceMismatchBinaryOp3 { + lhs: DeviceLocation, + rhs: DeviceLocation, + c: DeviceLocation, + op: &'static str, + }, + + // === Op Specific Errors === + #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")] + NarrowInvalidArgs { + shape: Shape, + dim: usize, + start: usize, + len: usize, + msg: &'static str, + }, + + #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")] + Conv1dInvalidArgs { + inp_shape: Shape, + k_shape: Shape, + padding: usize, + stride: usize, + msg: &'static str, + }, + + #[error("{op} invalid index {index} with dim size {size}")] + InvalidIndex { + op: &'static str, + index: usize, + size: usize, + }, + + #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] + BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, + + #[error("cannot set variable {msg}")] + CannotSetVar { msg: &'static str }, + + // Box indirection to avoid large variant. 
+ #[error("{0:?}")] + MatMulUnexpectedStriding(Box), + + #[error("{op} only supports contiguous tensors")] + RequiresContiguous { op: &'static str }, + + #[error("{op} expects at least one tensor")] + OpRequiresAtLeastOneTensor { op: &'static str }, + + #[error("{op} expects at least two tensors")] + OpRequiresAtLeastTwoTensors { op: &'static str }, + + #[error("backward is not supported for {op}")] + BackwardNotSupported { op: &'static str }, + + // === Other Errors === + #[error("the mcandle crate has not been built with cuda support")] + NotCompiledWithCudaSupport, + + #[error("the mcandle crate has not been built with metal support")] + NotCompiledWithMetalSupport, + + #[error("cannot find tensor {path}")] + CannotFindTensor { path: String }, + + // === Wrapped Errors === + #[error(transparent)] + Cuda(Box), + + #[error("Metal error {0}")] + Metal(#[from] MetalError), + + #[error(transparent)] + Ug(#[from] ug::Error), + + #[error(transparent)] + TryFromIntError(#[from] core::num::TryFromIntError), + + #[error("npy/npz error {0}")] + Npy(String), + + /// Zip file format error. + #[error(transparent)] + Zip(#[from] zip::result::ZipError), + + /// Integer parse error. + #[error(transparent)] + ParseInt(#[from] std::num::ParseIntError), + + /// Utf8 parse error. + #[error(transparent)] + FromUtf8(#[from] std::string::FromUtf8Error), + + /// I/O error. + #[error(transparent)] + Io(#[from] std::io::Error), + + /// SafeTensor error. + #[error(transparent)] + SafeTensor(#[from] safetensors::SafeTensorError), + + #[error("unsupported safetensor dtype {0:?}")] + UnsupportedSafeTensorDtype(safetensors::Dtype), + + /// Arbitrary errors wrapping. + #[error(transparent)] + Wrapped(Box), + + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, + + /// Adding path information to an error. + #[error("path: {path:?} {inner}")] + WithPath { + inner: Box, + path: std::path::PathBuf, + }, + + #[error("{inner}\n{backtrace}")] + WithBacktrace { + inner: Box, + backtrace: Box, + }, + + /// User generated error message, typically created via `bail!`. + #[error("{0}")] + Msg(String), +} + +pub type Result = std::result::Result; + +impl Error { + /// Create a new error by wrapping another. + pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { + Self::Wrapped(Box::new(err)).bt() + } + + /// Create a new error based on a printable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn msg(msg: M) -> Self { + Self::Msg(msg.to_string()).bt() + } + + /// Create a new error based on a debuggable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn debug(err: impl std::fmt::Debug) -> Self { + Self::Msg(format!("{err:?}")).bt() + } + + pub fn bt(self) -> Self { + let backtrace = std::backtrace::Backtrace::capture(); + match backtrace.status() { + std::backtrace::BacktraceStatus::Disabled + | std::backtrace::BacktraceStatus::Unsupported => self, + _ => Self::WithBacktrace { + inner: Box::new(self), + backtrace: Box::new(backtrace), + }, + } + } + + pub fn with_path>(self, p: P) -> Self { + Self::WithPath { + inner: Box::new(self), + path: p.as_ref().to_path_buf(), + } + } +} + +#[macro_export] +macro_rules! bail { + ($msg:literal $(,)?) => { + return Err($crate::Error::Msg(format!($msg).into()).bt()) + }; + ($err:expr $(,)?) 
=> { + return Err($crate::Error::Msg(format!($err).into()).bt()) + }; + ($fmt:expr, $($arg:tt)*) => { + return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt()) + }; +} + +pub fn zip(r1: Result, r2: Result) -> Result<(T, U)> { + match (r1, r2) { + (Ok(r1), Ok(r2)) => Ok((r1, r2)), + (Err(e), _) => Err(e), + (_, Err(e)) => Err(e), + } +} + +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { + /// Wrap the error value with additional context. + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. 
+ match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), + } + } +} diff --git a/mcandle/mcandle-core/src/indexer.rs b/mcandle/mcandle-core/src/indexer.rs new file mode 100644 index 000000000..04b7cf088 --- /dev/null +++ b/mcandle/mcandle-core/src/indexer.rs @@ -0,0 +1,257 @@ +use crate::{Error, Tensor}; +use std::ops::{ + Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, +}; + +impl Tensor { + /// Intended to be use by the trait `.i()` + /// + /// ``` + /// # use mcandle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = a.i(0..1)?; + /// assert_eq!(c.shape().dims(), &[1, 3]); + /// + /// let c = a.i(0)?; + /// assert_eq!(c.shape().dims(), &[3]); + /// + /// let c = a.i((.., ..2) )?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// + /// let c = a.i((.., ..=2))?; + /// assert_eq!(c.shape().dims(), &[2, 3]); + /// + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + fn index(&self, indexers: &[TensorIndexer]) -> Result { + let mut x = self.clone(); + let dims = self.shape().dims(); + let mut current_dim = 0; + for (i, indexer) in indexers.iter().enumerate() { + x = match indexer { + TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?, + TensorIndexer::Narrow(left_bound, right_bound) => { + let start = match left_bound { + Bound::Included(n) => *n, + Bound::Excluded(n) => *n + 1, + Bound::Unbounded => 0, + }; + let stop = match right_bound { + Bound::Included(n) => *n + 1, + Bound::Excluded(n) => *n, + Bound::Unbounded => dims[i], + }; + let out = x.narrow(current_dim, start, stop.saturating_sub(start))?; + current_dim += 1; + out + } + TensorIndexer::IndexSelect(indexes) => { + if indexes.rank() != 1 { + crate::bail!("multi-dimensional tensor indexing is not supported") + } + let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?; + current_dim += 1; + out + } + TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"), + }; + } + Ok(x) + } +} + +#[derive(Debug)] +/// Generic structure used to index a slice of the tensor +pub enum TensorIndexer { + /// This selects the elements for which an index has some specific value. 
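+    /// For example, `tensor.i(2)?` narrows the first dimension to index 2 and then squeezes
+    /// that dimension away.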
+ Select(usize), + /// This is a regular slice, purely indexing a chunk of the tensor + Narrow(Bound, Bound), + /// Indexing via a 1d tensor + IndexSelect(Tensor), + Err(Error), +} + +impl From for TensorIndexer { + fn from(index: usize) -> Self { + TensorIndexer::Select(index) + } +} + +impl From<&[u32]> for TensorIndexer { + fn from(index: &[u32]) -> Self { + match Tensor::new(index, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From> for TensorIndexer { + fn from(index: Vec) -> Self { + let len = index.len(); + match Tensor::from_vec(index, len, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From<&Tensor> for TensorIndexer { + fn from(tensor: &Tensor) -> Self { + TensorIndexer::IndexSelect(tensor.clone()) + } +} + +trait RB: RangeBounds {} +impl RB for Range {} +impl RB for RangeFrom {} +impl RB for RangeFull {} +impl RB for RangeInclusive {} +impl RB for RangeTo {} +impl RB for RangeToInclusive {} + +impl From for TensorIndexer { + fn from(range: T) -> Self { + use std::ops::Bound::*; + let start = match range.start_bound() { + Included(idx) => Included(*idx), + Excluded(idx) => Excluded(*idx), + Unbounded => Unbounded, + }; + let end = match range.end_bound() { + Included(idx) => Included(*idx), + Excluded(idx) => Excluded(*idx), + Unbounded => Unbounded, + }; + TensorIndexer::Narrow(start, end) + } +} + +/// Trait used to implement multiple signatures for ease of use of the slicing +/// of a tensor +pub trait IndexOp { + /// Returns a slicing iterator which are the chunks of data necessary to + /// reconstruct the desired tensor. + fn i(&self, index: T) -> Result; +} + +impl IndexOp for Tensor +where + T: Into, +{ + ///```rust + /// use mcandle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[ + /// [0., 1.], + /// [2., 3.], + /// [4., 5.] + /// ], &Device::Cpu)?; + /// + /// let b = a.i(0)?; + /// assert_eq!(b.shape().dims(), &[2]); + /// assert_eq!(b.to_vec1::()?, &[0., 1.]); + /// + /// let c = a.i(..2)?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [0., 1.], + /// [2., 3.] + /// ]); + /// + /// let d = a.i(1..)?; + /// assert_eq!(d.shape().dims(), &[2, 2]); + /// assert_eq!(d.to_vec2::()?, &[ + /// [2., 3.], + /// [4., 5.] + /// ]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + fn i(&self, index: T) -> Result { + self.index(&[index.into()]) + } +} + +impl IndexOp<(A,)> for Tensor +where + A: Into, +{ + ///```rust + /// use mcandle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[ + /// [0f32, 1.], + /// [2. , 3.], + /// [4. , 5.] + /// ], &Device::Cpu)?; + /// + /// let b = a.i((0,))?; + /// assert_eq!(b.shape().dims(), &[2]); + /// assert_eq!(b.to_vec1::()?, &[0., 1.]); + /// + /// let c = a.i((..2,))?; + /// assert_eq!(c.shape().dims(), &[2, 2]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [0., 1.], + /// [2., 3.] + /// ]); + /// + /// let d = a.i((1..,))?; + /// assert_eq!(d.shape().dims(), &[2, 2]); + /// assert_eq!(d.to_vec2::()?, &[ + /// [2., 3.], + /// [4., 5.] 
+ /// ]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + fn i(&self, (a,): (A,)) -> Result { + self.index(&[a.into()]) + } +} +#[allow(non_snake_case)] +impl IndexOp<(A, B)> for Tensor +where + A: Into, + B: Into, +{ + ///```rust + /// use mcandle_core::{Tensor, DType, Device, IndexOp}; + /// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?; + /// + /// let b = a.i((1, 0))?; + /// assert_eq!(b.to_vec0::()?, 3.); + /// + /// let c = a.i((..2, 1))?; + /// assert_eq!(c.shape().dims(), &[2]); + /// assert_eq!(c.to_vec1::()?, &[1., 4.]); + /// + /// let d = a.i((2.., ..))?; + /// assert_eq!(c.shape().dims(), &[2]); + /// assert_eq!(c.to_vec1::()?, &[1., 4.]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + fn i(&self, (a, b): (A, B)) -> Result { + self.index(&[a.into(), b.into()]) + } +} + +macro_rules! index_op_tuple { + ($doc:tt, $($t:ident),+) => { + #[allow(non_snake_case)] + impl<$($t),*> IndexOp<($($t,)*)> for Tensor + where + $($t: Into,)* + { + #[doc=$doc] + fn i(&self, ($($t,)*): ($($t,)*)) -> Result { + self.index(&[$($t.into(),)*]) + } + } + }; +} + +index_op_tuple!("see [TensorIndex#method.i]", A, B, C); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F); +index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G); diff --git a/mcandle/mcandle-core/src/layout.rs b/mcandle/mcandle-core/src/layout.rs new file mode 100644 index 000000000..7e3b7afbb --- /dev/null +++ b/mcandle/mcandle-core/src/layout.rs @@ -0,0 +1,277 @@ +use crate::{Error, Result, Shape}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Layout { + shape: Shape, + // The strides are given in number of elements and not in bytes. + stride: Vec, + start_offset: usize, +} + +impl Layout { + pub fn new(shape: Shape, stride: Vec, start_offset: usize) -> Self { + Self { + shape, + stride, + start_offset, + } + } + + pub fn contiguous_with_offset>(shape: S, start_offset: usize) -> Self { + let shape = shape.into(); + let stride = shape.stride_contiguous(); + Self { + shape, + stride, + start_offset, + } + } + + pub fn contiguous>(shape: S) -> Self { + Self::contiguous_with_offset(shape, 0) + } + + pub fn dims(&self) -> &[usize] { + self.shape.dims() + } + + /// The dimension size for a specified dimension index. + pub fn dim(&self, dim: D) -> Result { + let dim = dim.to_index(&self.shape, "dim")?; + Ok(self.dims()[dim]) + } + + pub fn shape(&self) -> &Shape { + &self.shape + } + + pub fn stride(&self) -> &[usize] { + &self.stride + } + + pub fn start_offset(&self) -> usize { + self.start_offset + } + + /// Returns the appropriate start and stop offset if the data is stored in a C + /// contiguous (aka row major) way. + pub fn contiguous_offsets(&self) -> Option<(usize, usize)> { + if self.is_contiguous() { + let start_o = self.start_offset; + Some((start_o, start_o + self.shape.elem_count())) + } else { + None + } + } + + /// Returns true if the data is stored in a C contiguous (aka row major) way. + /// Note that this does not implies that the start offset is 0 or that there are no extra + /// elements at the end of the storage. + pub fn is_contiguous(&self) -> bool { + self.shape.is_contiguous(&self.stride) + } + + /// Returns true if the data is stored in a Fortran contiguous (aka column major) way. 
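The layout code above assumes the usual row-major (C-contiguous) stride convention that `Layout::contiguous` and `contiguous_offsets` build on. Below is a minimal standalone sketch of how those strides are derived; the helper names are illustrative and not part of the patch.

```rust
/// Illustrative sketch (not part of the patch): strides for a C-contiguous
/// (row-major) layout are built right-to-left, each dimension's stride being
/// the element count of everything to its right.
fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    // A (2, 3, 4) tensor stored contiguously has strides [12, 4, 1]:
    // moving one step along dim 0 skips 3 * 4 = 12 elements, and so on.
    assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
    // With a start offset `o`, the contiguous span covers [o, o + elem_count),
    // which is the pair returned by `contiguous_offsets`.
    let elem_count: usize = [2, 3, 4].iter().product();
    assert_eq!(elem_count, 24);
}
```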
+ pub fn is_fortran_contiguous(&self) -> bool { + self.shape.is_fortran_contiguous(&self.stride) + } + + pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result { + let dims = self.shape().dims(); + if dim >= dims.len() { + Err(Error::DimOutOfRange { + shape: self.shape().clone(), + dim: dim as i32, + op: "narrow", + } + .bt())? + } + if start + len > dims[dim] { + Err(Error::NarrowInvalidArgs { + shape: self.shape.clone(), + dim, + start, + len, + msg: "start + len > dim_len", + } + .bt())? + } + let mut dims = dims.to_vec(); + dims[dim] = len; + Ok(Self { + shape: Shape::from(dims), + stride: self.stride.clone(), + start_offset: self.start_offset + self.stride[dim] * start, + }) + } + + pub fn transpose(&self, dim1: usize, dim2: usize) -> Result { + let rank = self.shape.rank(); + if rank <= dim1 || rank <= dim2 { + Err(Error::UnexpectedNumberOfDims { + expected: usize::max(dim1, dim2), + got: rank, + shape: self.shape().clone(), + } + .bt())? + } + let mut stride = self.stride().to_vec(); + let mut dims = self.shape().dims().to_vec(); + dims.swap(dim1, dim2); + stride.swap(dim1, dim2); + Ok(Self { + shape: Shape::from(dims), + stride, + start_offset: self.start_offset, + }) + } + + pub fn permute(&self, idxs: &[usize]) -> Result { + let is_permutation = + idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i)); + if !is_permutation { + crate::bail!( + "dimension mismatch in permute, tensor {:?}, dims: {:?}", + self.dims(), + idxs + ) + } + let stride = self.stride(); + let dims = self.shape().dims(); + let mut perm_stride = stride.to_vec(); + let mut perm_dims = dims.to_vec(); + for (i, &idx) in idxs.iter().enumerate() { + perm_stride[i] = stride[idx]; + perm_dims[i] = dims[idx]; + } + Ok(Self { + shape: Shape::from(perm_dims), + stride: perm_stride, + start_offset: self.start_offset, + }) + } + + pub fn broadcast_as>(&self, shape: S) -> Result { + let shape = shape.into(); + if shape.rank() < self.shape().rank() { + return Err(Error::BroadcastIncompatibleShapes { + src_shape: self.shape().clone(), + dst_shape: shape, + } + .bt()); + } + let added_dims = shape.rank() - self.shape().rank(); + let mut stride = vec![0; added_dims]; + for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..] + .iter() + .zip(self.dims().iter().zip(self.stride())) + { + let s = if dst_dim == src_dim { + src_stride + } else if src_dim != 1 { + return Err(Error::BroadcastIncompatibleShapes { + src_shape: self.shape().clone(), + dst_shape: shape, + } + .bt()); + } else { + 0 + }; + stride.push(s) + } + Ok(Self { + shape, + stride, + start_offset: self.start_offset, + }) + } + + pub(crate) fn strided_index(&self) -> crate::StridedIndex { + crate::StridedIndex::from_layout(self) + } + + pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks { + let mut block_len = 1; + let mut contiguous_dims = 0; // These are counted from the right. 
+ for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() { + if stride != block_len { + break; + } + block_len *= dim; + contiguous_dims += 1; + } + let index_dims = self.dims().len() - contiguous_dims; + if index_dims == 0 { + crate::StridedBlocks::SingleBlock { + start_offset: self.start_offset, + len: block_len, + } + } else { + let block_start_index = crate::StridedIndex::new( + &self.dims()[..index_dims], + &self.stride[..index_dims], + self.start_offset, + ); + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } + } + } + + // Returns the contiguous offsets with broadcast if applicable. + pub(crate) fn offsets_b(&self) -> Option { + let mut left_broadcast = 1; + let mut right_broadcast = 1; + let strides = self.stride(); + let dims = self.dims(); + let mut start_cont = 0; + let mut end_cont = dims.len(); + for (&s, &d) in strides.iter().zip(dims.iter()) { + if s != 0 { + break; + } + start_cont += 1; + left_broadcast *= d; + } + if start_cont == dims.len() { + return Some(ContiguousOffsetsWithBroadcast { + start: self.start_offset, + len: 1, + left_broadcast, + right_broadcast: 1, + }); + } + for (&s, &d) in strides.iter().zip(dims.iter()).rev() { + if s != 0 { + break; + } + end_cont -= 1; + right_broadcast *= d; + } + // Check that the inner dims are contiguous + let strides = &strides[start_cont..end_cont]; + let dims = &dims[start_cont..end_cont]; + let mut len = 1; + for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() { + if stride != len { + return None; + } + len *= dim; + } + Some(ContiguousOffsetsWithBroadcast { + start: self.start_offset, + len, + left_broadcast, + right_broadcast, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ContiguousOffsetsWithBroadcast { + pub start: usize, + pub len: usize, + pub left_broadcast: usize, + pub right_broadcast: usize, +} diff --git a/mcandle/mcandle-core/src/lib.rs b/mcandle/mcandle-core/src/lib.rs new file mode 100644 index 000000000..9aafa9151 --- /dev/null +++ b/mcandle/mcandle-core/src/lib.rs @@ -0,0 +1,174 @@ +//! ML framework for Rust +//! +//! ```rust +//! use mcandle_core::{Tensor, DType, Device}; +//! # use mcandle_core::Error; +//! # fn main() -> Result<(), Error>{ +//! +//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; +//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?; +//! +//! let c = a.matmul(&b)?; +//! # Ok(())} +//! ``` +//! +//! ## Features +//! +//! - Simple syntax (looks and feels like PyTorch) +//! - CPU and Cuda backends (and M1 support) +//! - Enable serverless (CPU) small and fast deployments +//! - Model training +//! - Distributed computing (NCCL). +//! - Models out of the box (Llama, Whisper, Falcon, ...) +//! +//! ## FAQ +//! +//! - Why Candle? +//! +//! Candle stems from the need to reduce binary size in order to *enable serverless* +//! possible by making the whole engine smaller than PyTorch very large library volume +//! +//! And simply *removing Python* from production workloads. +//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches. +//! +//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers) +//! +//! ## Other Crates +//! +//! Candle consists of a number of crates. 
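`Layout::broadcast_as` and `offsets_b` above rely on the zero-stride broadcasting convention: newly added leading dimensions and expanded size-1 dimensions get a stride of 0, so every index along them maps to the same element. A standalone sketch of that stride computation, with illustrative names only:

```rust
/// Illustrative sketch (not part of the patch) of zero-stride broadcasting,
/// mirroring the checks performed by `Layout::broadcast_as`.
fn broadcast_strides(
    src_dims: &[usize],
    src_strides: &[usize],
    dst_dims: &[usize],
) -> Option<Vec<usize>> {
    if dst_dims.len() < src_dims.len() {
        return None;
    }
    let added = dst_dims.len() - src_dims.len();
    let mut strides = vec![0usize; added]; // added leading dims broadcast with stride 0
    for ((&d_dst, &d_src), &s_src) in dst_dims[added..]
        .iter()
        .zip(src_dims.iter())
        .zip(src_strides.iter())
    {
        if d_dst == d_src {
            strides.push(s_src);
        } else if d_src == 1 {
            strides.push(0); // broadcast dim: every index maps to the same element
        } else {
            return None; // incompatible shapes
        }
    }
    Some(strides)
}

fn main() {
    // (3, 1) broadcast to (2, 3, 4): the leading dim and the expanded dim get stride 0.
    assert_eq!(
        broadcast_strides(&[3, 1], &[1, 1], &[2, 3, 4]),
        Some(vec![0, 1, 0])
    );
}
```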
This crate holds core the common data structures but you may wish +//! to look at the docs for the other crates which can be found here: +//! +//! - [mcandle-core](https://docs.rs/mcandle-core/). Core Datastructures and DataTypes. +//! - [mcandle-nn](https://docs.rs/mcandle-nn/). Building blocks for Neural Nets. +//! - [mcandle-datasets](https://docs.rs/mcandle-datasets/). Rust access to commonly used Datasets like MNIST. +//! - [mcandle-examples](https://docs.rs/mcandle-examples/). Examples of Candle in Use. +//! - [mcandle-onnx](https://docs.rs/mcandle-onnx/). Loading and using ONNX models. +//! - [mcandle-pyo3](https://docs.rs/mcandle-pyo3/). Access to Candle from Python. +//! - [mcandle-transformers](https://docs.rs/mcandle-transformers/). Candle implemntation of many published transformer models. +//! + +#[cfg(feature = "accelerate")] +mod accelerate; +pub mod backend; +pub mod backprop; +pub mod conv; +mod convert; +pub mod cpu; +pub mod cpu_backend; +#[cfg(feature = "cuda")] +pub mod cuda_backend; +mod custom_op; +mod device; +pub mod display; +mod dtype; +pub mod dummy_cuda_backend; +mod dummy_metal_backend; +pub mod error; +mod indexer; +pub mod layout; +#[cfg(feature = "metal")] +pub mod metal_backend; +#[cfg(feature = "mkl")] +mod mkl; +pub mod npy; +pub mod op; +pub mod pickle; +pub mod quantized; +pub mod safetensors; +pub mod scalar; +pub mod shape; +mod sort; +mod storage; +pub mod streaming; +mod strided_index; +mod tensor; +mod tensor_cat; +mod tensor_indexing; +pub mod test_utils; +pub mod utils; +mod variable; + +#[cfg(feature = "cudnn")] +pub use cuda_backend::cudnn; + +pub use cpu_backend::{CpuStorage, CpuStorageRef}; +pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; +pub use device::{Device, DeviceLocation, NdArray}; +pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; +pub use error::{Context, Error, Result}; +pub use indexer::{IndexOp, TensorIndexer}; +pub use layout::Layout; +pub use shape::{Shape, D}; +pub use storage::Storage; +pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule}; +pub use strided_index::{StridedBlocks, StridedIndex}; +pub use tensor::{from_storage_no_op, Tensor, TensorId}; +pub use variable::Var; + +#[cfg(feature = "cuda")] +pub use cuda_backend as cuda; + +#[cfg(not(feature = "cuda"))] +pub use dummy_cuda_backend as cuda; + +pub use cuda::{CudaDevice, CudaStorage}; + +#[cfg(feature = "metal")] +pub use metal_backend::{MetalDevice, MetalError, MetalStorage}; + +#[cfg(not(feature = "metal"))] +pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage}; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +pub trait ToUsize2 { + fn to_usize2(self) -> (usize, usize); +} + +impl ToUsize2 for usize { + fn to_usize2(self) -> (usize, usize) { + (self, self) + } +} + +impl ToUsize2 for (usize, usize) { + fn to_usize2(self) -> (usize, usize) { + self + } +} + +// A simple trait defining a module with forward method using a single argument. 
+pub trait Module { + fn forward(&self, xs: &Tensor) -> Result; +} + +impl Result> Module for T { + fn forward(&self, xs: &Tensor) -> Result { + self(xs) + } +} + +impl Module for Option<&M> { + fn forward(&self, xs: &Tensor) -> Result { + match self { + None => Ok(xs.clone()), + Some(m) => m.forward(xs), + } + } +} + +// A trait defining a module with forward method using a single tensor argument and a flag to +// separate the training and evaluation behaviors. +pub trait ModuleT { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result; +} + +impl ModuleT for M { + fn forward_t(&self, xs: &Tensor, _train: bool) -> Result { + self.forward(xs) + } +} diff --git a/mcandle/mcandle-core/src/metal_backend/device.rs b/mcandle/mcandle-core/src/metal_backend/device.rs new file mode 100644 index 000000000..57a3df040 --- /dev/null +++ b/mcandle/mcandle-core/src/metal_backend/device.rs @@ -0,0 +1,348 @@ +use crate::{DType, Result}; +use mcandle_metal_kernels::Kernels; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::path::Path; +use std::sync::{Arc, Mutex, RwLock}; + +use super::MetalError; + +/// Unique identifier for cuda devices. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + pub(crate) fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec>>; +pub(crate) struct Commands { + /// Single command queue for the entire device. + command_queue: CommandQueue, + /// One command buffer at a time. + /// The scheduler works by allowing multiple + /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) + /// on a single command buffer. Using a single command buffer would be fastest on the GPU but + /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed + /// to start to work). + /// Despite what the documentation says, command buffers are NOT ordered. They are ordered + /// for their START time, but there's no guarantee that command buffer1 will finish before + /// command buffer2 starts (or there are metal bugs there) + command_buffer: CommandBuffer, + /// Keeps track of the current amount of compute command encoders on the current + /// command buffer + /// Arc, RwLock because of the interior mutability. 
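A short usage sketch of the `Module` blanket impls defined above: because `Module` is implemented for `Fn(&Tensor) -> Result<Tensor>` and for `Option<&M>`, a plain closure can stand in for a module and a `None` module acts as the identity. This assumes the mcandle_core API added in this patch and is illustrative only.

```rust
use mcandle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::arange(0f32, 4f32, &Device::Cpu)?;

    // A closure returning Result<Tensor> is a Module via the blanket impl.
    let reshape = |t: &Tensor| t.reshape((2, 2));
    let ys = reshape.forward(&xs)?;
    assert_eq!(ys.shape().dims(), &[2, 2]);

    // A `None` module forwards its input unchanged.
    let identity: Option<&fn(&Tensor) -> Result<Tensor>> = None;
    let zs = identity.forward(&xs)?;
    assert_eq!(zs.shape().dims(), xs.shape().dims());
    Ok(())
}
```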
+ command_buffer_index: usize, + /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) + compute_per_buffer: usize, +} + +impl Commands { + pub(crate) fn new(command_queue: CommandQueue) -> Result { + let command_buffer = command_queue.new_command_buffer().to_owned(); + command_buffer.enqueue(); + let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") { + Ok(val) => val.parse()?, + _ => 50, + }; + Ok(Self { + command_queue, + command_buffer, + command_buffer_index: 0, + compute_per_buffer, + }) + } + + pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> { + let mut command_buffer = self.command_buffer.to_owned(); + let mut flushed = false; + if self.command_buffer_index > self.compute_per_buffer { + self.command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + self.command_buffer = command_buffer.clone(); + self.command_buffer_index = 0; + flushed = true; + } + self.command_buffer_index += 1; + Ok((flushed, command_buffer)) + } + + pub fn wait_until_completed(&mut self) -> Result<()> { + match self.command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Already committed"); + } + _ => {} + } + self.command_buffer.commit(); + self.command_buffer.wait_until_completed(); + self.command_buffer = self.command_queue.new_command_buffer().to_owned(); + + Ok(()) + } +} + +#[derive(Clone)] +pub struct MetalDevice { + /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than + /// the device itself. + pub(crate) id: DeviceId, + + /// Raw metal device: + pub(crate) device: metal::Device, + + pub(crate) commands: Arc>, + + /// Simple allocator struct. + /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. + /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting + /// (could be linked to FFI communication overhead). + /// + /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the + /// graph calculation, and only we the allocator kept a reference to it, therefore it's free + /// to be reused. However, in order for this to work, we need to guarantee the order of + /// operation, so that this buffer is not being used by another kernel at the same time. + /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. + /// + /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers + /// (strong_count = 1). + pub(crate) buffers: Arc>, + + /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. + /// Heavily used by [`mcandle_metal_kernels`] + pub(crate) kernels: Arc, + /// Seed for random number generation. + pub(crate) seed: Arc>, + /// Whether to use the MLX matmul kernels instead of the MFA ones. 
+ pub(crate) use_mlx_mm: bool, + /// Value of the current seed + pub(crate) seed_value: Arc>, +} + +impl std::fmt::Debug for MetalDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MetalDevice({:?})", self.id) + } +} + +impl std::ops::Deref for MetalDevice { + type Target = metal::DeviceRef; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl MetalDevice { + pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) { + self.use_mlx_mm = use_mlx_mm + } + + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?; + let metal_code = String::from_utf8(buf)?; + let lib = self + .device + .new_library_with_source(&metal_code, &metal::CompileOptions::new()) + .map_err(MetalError::from)?; + let func = lib + .get_function(func_name, None) + .map_err(MetalError::from)?; + let pl = self + .device + .new_compute_pipeline_state_with_function(&func) + .map_err(MetalError::from)?; + Ok(pl) + } + + pub fn id(&self) -> DeviceId { + self.id + } + + pub fn metal_device(&self) -> &metal::Device { + &self.device + } + + fn drop_unused_buffers(&self) -> Result<()> { + let mut buffers = self.buffers.write().map_err(MetalError::from)?; + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(*s) > 1) + .map(Arc::clone) + .collect(); + *subbuffers = newbuffers; + } + Ok(()) + } + + pub fn command_buffer(&self) -> Result { + let mut commands = self.commands.write().map_err(MetalError::from)?; + let (flushed, command_buffer) = commands.command_buffer()?; + if flushed { + self.drop_unused_buffers()? + } + Ok(command_buffer) + } + + pub fn wait_until_completed(&self) -> Result<()> { + let mut commands = self.commands.write().map_err(MetalError::from)?; + commands.wait_until_completed() + } + + pub fn kernels(&self) -> &Kernels { + &self.kernels + } + + pub fn device(&self) -> &metal::Device { + &self.device + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer data cannot be read on the CPU directly. + /// + /// [`name`] is only used to keep track of the resource origin in case of bugs + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result> { + let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, name) + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer can be read on the CPU but will require manual + /// synchronization when the CPU memory is modified + /// Used as a bridge to gather data back from the GPU + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { + self.allocate_buffer(size, MTLResourceOptions::StorageModeShared, "managed") + } + + /// Creates a new buffer from data. + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// + /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) + /// allocates the buffer and copies over the existing data before returning the MTLBuffer. 
+ pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { + let size = core::mem::size_of_val(data) as NSUInteger; + let new_buffer = self.device.new_buffer_with_data( + data.as_ptr() as *const c_void, + size, + MTLResourceOptions::StorageModeShared, + ); + let mut buffers = self.buffers.write().map_err(MetalError::from)?; + + let subbuffers = buffers + .entry((size, MTLResourceOptions::StorageModeShared)) + .or_insert(vec![]); + + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + Ok(new_buffer) + } + + pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { + let buffer = self.allocate_buffer( + size_in_bytes as NSUInteger, + MTLResourceOptions::StorageModeShared, + "allocate_zeros", + )?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("zeros"); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + Ok(buffer) + } + + /// The critical allocator algorithm + fn allocate_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Result> { + let mut buffers = self.buffers.write().map_err(MetalError::from)?; + if let Some(b) = find_available_buffer(size, option, &buffers) { + // Cloning also ensures we increment the strong count + return Ok(b.clone()); + } + + let size = buf_size(size); + let subbuffers = buffers.entry((size, option)).or_insert(vec![]); + + let new_buffer = self.device.new_buffer(size as NSUInteger, option); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + + Ok(new_buffer) + } + + /// Create a metal GPU capture trace on [`path`]. + pub fn capture>(&self, path: P) -> Result<()> { + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(self); + // The [set_output_url] call requires an absolute path so we convert it if needed. 
+ if path.as_ref().is_absolute() { + descriptor.set_output_url(path); + } else { + let path = std::env::current_dir()?.join(path); + descriptor.set_output_url(path); + } + + capture + .start_capture(&descriptor) + .map_err(MetalError::from)?; + Ok(()) + } +} + +fn buf_size(size: NSUInteger) -> NSUInteger { + size.saturating_sub(1).next_power_of_two() as NSUInteger +} + +fn find_available_buffer( + size: NSUInteger, + option: MTLResourceOptions, + buffers: &BufferMap, +) -> Option> { + let mut best_buffer: Option<&Arc> = None; + let mut best_buffer_size: NSUInteger = NSUInteger::MAX; + for ((buffer_size, buffer_option), subbuffers) in buffers.iter() { + if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option { + for sub in subbuffers { + if Arc::strong_count(sub) == 1 { + best_buffer = Some(sub); + best_buffer_size = *buffer_size; + } + } + } + } + best_buffer.cloned() +} diff --git a/mcandle/mcandle-core/src/metal_backend/mod.rs b/mcandle/mcandle-core/src/metal_backend/mod.rs new file mode 100644 index 000000000..cac81eda9 --- /dev/null +++ b/mcandle/mcandle-core/src/metal_backend/mod.rs @@ -0,0 +1,2297 @@ +use crate::backend::{BackendDevice, BackendStorage}; +use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; +use mcandle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; +use metal::{Buffer, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError}; + +mod device; +pub use device::{DeviceId, MetalDevice}; + +pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> { + BufferOffset { + buffer, + offset_in_bytes: l.start_offset() * dtype.size_in_bytes(), + } +} +/// Simple way to catch lock error without +/// depending on T +#[derive(thiserror::Error, Debug)] +pub enum LockError { + #[error("{0}")] + Poisoned(String), + #[error("Would block")] + WouldBlock, +} + +impl From> for MetalError { + fn from(value: TryLockError) -> Self { + match value { + TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())), + TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock), + } + } +} + +impl From> for MetalError { + fn from(p: PoisonError) -> Self { + MetalError::LockError(LockError::Poisoned(p.to_string())) + } +} + +/// Metal related errors +#[derive(thiserror::Error, Debug)] +pub enum MetalError { + #[error("{0}")] + Message(String), + #[error(transparent)] + KernelError(#[from] mcandle_metal_kernels::MetalKernelError), + #[error("{0:?}")] + LockError(LockError), + #[error("{msg}, expected: {expected:?}, got: {got:?}")] + UnexpectedDType { + msg: &'static str, + expected: DType, + got: DType, + }, +} + +impl From for MetalError { + fn from(e: String) -> Self { + MetalError::Message(e) + } +} + +#[derive(Debug, Clone)] +pub struct MetalStorage { + /// The actual buffer containing the data. + buffer: Arc, + /// a reference to the device owning this buffer + device: MetalDevice, + /// The count of allocated elements in the buffer + count: usize, + /// The dtype is kept since buffers are untyped. 
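The pooled-allocator helpers above (`buf_size`, `find_available_buffer`) bucket requests by power-of-two size and only hand a buffer back out when the pool holds its sole strong reference. A standalone sketch of both rules, with illustrative names and `Vec<u8>` standing in for a Metal buffer:

```rust
use std::sync::Arc;

/// Illustrative sketch (not part of the patch): requests are mapped to a
/// power-of-two bucket, mirroring `buf_size` above, so tensors of similar
/// size reuse the same pool entry.
fn bucket_size(requested: u64) -> u64 {
    requested.saturating_sub(1).next_power_of_two()
}

fn main() {
    assert_eq!(bucket_size(1), 1);
    assert_eq!(bucket_size(12), 16);
    assert_eq!(bucket_size(16), 16);
    assert_eq!(bucket_size(100), 128);

    // A pooled buffer is reusable only when the allocator holds the sole
    // strong reference, i.e. no in-flight kernel still points at it.
    let buf = Arc::new(vec![0u8; 16]);
    assert_eq!(Arc::strong_count(&buf), 1); // free to reuse
    let in_flight = Arc::clone(&buf);
    assert_eq!(Arc::strong_count(&buf), 2); // still referenced elsewhere
    drop(in_flight);
    assert_eq!(Arc::strong_count(&buf), 1); // reusable again
}
```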
+ dtype: DType, +} + +impl BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Ok(self.clone()) + } + + fn dtype(&self) -> DType { + self.dtype + } + + fn device(&self) -> &Self::Device { + &self.device + } + + fn to_cpu_storage(&self) -> Result { + match self.dtype { + DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), + DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I16 => Ok(CpuStorage::I16(self.to_cpu()?)), + DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), + DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), + DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), + DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), + DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), + DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), + DType::F8E4M3 => Ok(CpuStorage::F64(self.to_cpu()?)), + } + } + + fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "affine")?; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, dtype); + if layout.is_contiguous() { + let name = match self.dtype { + DType::F32 => "affine_f32", + DType::F16 => "affine_f16", + DType::BF16 => "affine_bf16", + DType::U8 => "affine_u8", + DType::U32 => "affine_u32", + dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + src, + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "affine_f32_strided", + DType::F16 => "affine_f16_strided", + DType::BF16 => "affine_bf16_strided", + dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_affine_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + src, + layout.stride(), + &buffer, + mul as f32, + add as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), el, dtype)) + } + + fn powf(&self, layout: &Layout, pow: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "powf")?; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, dtype); + if layout.is_contiguous() { + let name = match self.dtype { + DType::F32 => "powf_f32", + DType::F16 => "powf_f16", + DType::BF16 => "powf_bf16", + dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_powf( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + src, + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "powf_f32_strided", + DType::F16 => "powf_f16_strided", + DType::BF16 => "powf_bf16_strided", + dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_powf_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + src, + layout.stride(), + &buffer, + pow as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), el, dtype)) + } + + fn elu(&self, layout: &Layout, alpha: f64) -> 
Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let el = shape.elem_count(); + let dtype = self.dtype; + + let buffer = device.new_buffer(el, self.dtype, "elu")?; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { + let name = match self.dtype { + DType::F32 => "elu_f32", + DType::F16 => "elu_f16", + DType::BF16 => "elu_bf16", + dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_elu( + &device.device, + &command_buffer, + &device.kernels, + name, + el, + src, + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "elu_f32_strided", + DType::F16 => "elu_f16_strided", + DType::BF16 => "elu_bf16_strided", + dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_elu_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + layout.dims(), + src, + layout.stride(), + &buffer, + alpha as f32, + ) + .map_err(MetalError::from)?; + } + Ok(Self::new(buffer, device.clone(), el, dtype)) + } + + fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { + let device = self.device.clone(); + let src_stride = layout.stride(); + let src_dims = layout.shape().dims(); + // Source dims and strides with the sum dims at the end. + let mut dims = vec![]; + let mut stride = vec![]; + let mut dst_el: usize = 1; + for (dim_idx, &d) in src_dims.iter().enumerate() { + if !sum_dims.contains(&dim_idx) { + dst_el *= d; + dims.push(d); + stride.push(src_stride[dim_idx]); + } + } + for &dim_idx in sum_dims.iter() { + dims.push(src_dims[dim_idx]); + stride.push(src_stride[dim_idx]); + } + + // The reduction loop requires the shared array to be properly initialized and for + // this we want the number of threads to be a power of two. 
+ let (name, check_empty, return_index) = match (op, self.dtype) { + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true), + (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false), + (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false), + (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false), + (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true), + (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true), + (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false), + (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false), + (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false), + (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true), + (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true), + (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false), + (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false), + (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), + (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), + (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I16) => ("fast_sum_i16_strided", false, false), + (ReduceOp::Min, DType::I16) => ("fast_min_i16_strided", true, false), + (ReduceOp::Max, DType::I16) => ("fast_max_i16_strided", true, false), + (ReduceOp::ArgMin, DType::I16) => ("fast_argmin_i16_strided", true, true), + (ReduceOp::ArgMax, DType::I16) => ("fast_argmax_i16_strided", true, true), + (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false), + (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false), + (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false), + (ReduceOp::ArgMin, DType::I32) => ("fast_argmin_i32_strided", true, true), + (ReduceOp::ArgMax, DType::I32) => ("fast_argmax_i32_strided", true, true), + (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false), + (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false), + (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), + (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true), + (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true), + (ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false), + (ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false), + (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false), + (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true), + (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true), + (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"), + }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? 
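Before dispatching to one of the `fast_*_strided` kernels above, `reduce_op` reorders dimensions so the reduced ones come last and the destination element count is the product of the kept dims. A standalone sketch of that reordering, with illustrative names:

```rust
/// Illustrative sketch (not part of the patch): move non-reduced dims (and
/// their strides) to the front and reduced dims to the back, as `reduce_op`
/// does before calling the strided reduction kernel.
fn reorder_for_reduce(
    dims: &[usize],
    strides: &[usize],
    reduce_dims: &[usize],
) -> (Vec<usize>, Vec<usize>, usize) {
    let mut out_dims = Vec::new();
    let mut out_strides = Vec::new();
    let mut dst_el = 1;
    for (i, (&d, &s)) in dims.iter().zip(strides.iter()).enumerate() {
        if !reduce_dims.contains(&i) {
            dst_el *= d;
            out_dims.push(d);
            out_strides.push(s);
        }
    }
    for &i in reduce_dims {
        out_dims.push(dims[i]);
        out_strides.push(strides[i]);
    }
    (out_dims, out_strides, dst_el)
}

fn main() {
    // Reducing a contiguous (2, 3, 4) tensor over dim 1 keeps 2 * 4 = 8 outputs;
    // the reduced dim (size 3, stride 4) ends up last.
    let (dims, strides, dst_el) = reorder_for_reduce(&[2, 3, 4], &[12, 4, 1], &[1]);
    assert_eq!(dims, vec![2, 4, 3]);
    assert_eq!(strides, vec![12, 1, 4]);
    assert_eq!(dst_el, 8);
}
```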
+ } + let dtype = if return_index { DType::U32 } else { self.dtype }; + let buffer = device.new_buffer(dst_el, dtype, "reduce")?; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, self.dtype); + mcandle_metal_kernels::call_reduce_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + &dims, + &stride, + dst_el, + src, + &buffer, + ) + .map_err(MetalError::from)?; + + Ok(Self::new(buffer, device, dst_el, dtype)) + } + + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { + let name = match op { + CmpOp::Eq => "eq", + CmpOp::Ne => "ne", + CmpOp::Le => "le", + CmpOp::Ge => "ge", + CmpOp::Lt => "lt", + CmpOp::Gt => "gt", + }; + self.binary(name, rhs, lhs_l, rhs_l) + } + + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + let device = self.device(); + let shape = layout.shape(); + let el_count = shape.elem_count(); + let buffer = device.new_buffer(el_count, dtype, "todtype")?; + let command_buffer = device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U32, DType::F16) => "cast_u32_f16", + (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I16) => "cast_u32_i16", + (DType::U32, DType::I32) => "cast_u32_i32", + (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::U8) => "cast_u32_u8", + + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::U8, DType::F16) => "cast_u8_f16", + (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I16) => "cast_u8_i16", + (DType::U8, DType::I32) => "cast_u8_i32", + (DType::U8, DType::I64) => "cast_u8_i64", + (DType::U8, DType::U32) => "cast_u8_u32", + + (DType::F32, DType::BF16) => "cast_f32_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I16) => "cast_f32_i16", + (DType::F32, DType::I32) => "cast_f32_i32", + (DType::F32, DType::I64) => "cast_f32_i64", + (DType::F32, DType::U32) => "cast_f32_u32", + (DType::F32, DType::U8) => "cast_f32_u8", + + (DType::I16, DType::BF16) => "cast_i16_bf16", + (DType::I16, DType::F16) => "cast_i16_f16", + (DType::I16, DType::F32) => "cast_i16_f32", + (DType::I16, DType::U32) => "cast_i16_u32", + (DType::I16, DType::U8) => "cast_i16_u8", + + (DType::I32, DType::BF16) => "cast_i32_bf16", + (DType::I32, DType::F16) => "cast_i32_f16", + (DType::I32, DType::F32) => "cast_i32_f32", + (DType::I32, DType::U32) => "cast_i32_u32", + (DType::I32, DType::U8) => "cast_i32_u8", + + (DType::I64, DType::BF16) => "cast_i64_bf16", + (DType::I64, DType::F16) => "cast_i64_f16", + (DType::I64, DType::F32) => "cast_i64_f32", + (DType::I64, DType::U32) => "cast_i64_u32", + (DType::I64, DType::U8) => "cast_i64_u8", + + (DType::F16, DType::BF16) => "cast_f16_bf16", + (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I16) => "cast_f16_i16", + (DType::F16, DType::I32) => "cast_f16_i32", + (DType::F16, DType::I64) => "cast_f16_i64", + (DType::F16, DType::U32) => "cast_f16_u32", + (DType::F16, DType::U8) => "cast_f16_u8", + + (DType::BF16, DType::F16) => "cast_bf16_f16", + (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I16) => "cast_bf16_i16", + (DType::BF16, DType::I32) => "cast_bf16_i32", + (DType::BF16, DType::I64) => "cast_bf16_i64", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::U8) => "cast_bf16_u8", + + (left, right) => { + crate::bail!("Metal contiguous 
to_dtype {left:?} {right:?} not implemented") + } + }; + mcandle_metal_kernels::call_cast_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } else { + let kernel_name = match (self.dtype, dtype) { + (DType::BF16, DType::F16) => "cast_bf16_f16_strided", + (DType::BF16, DType::F32) => "cast_bf16_f32_strided", + (DType::BF16, DType::I64) => "cast_bf16_i64_strided", + (DType::BF16, DType::U32) => "cast_bf16_u32_strided", + (DType::BF16, DType::U8) => "cast_bf16_u8_strided", + + (DType::F16, DType::BF16) => "cast_f16_bf16_strided", + (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::F16, DType::I64) => "cast_f16_i64_strided", + (DType::F16, DType::U32) => "cast_f16_u32_strided", + (DType::F16, DType::U8) => "cast_f16_u8_strided", + + (DType::F32, DType::BF16) => "cast_f32_bf16_strided", + (DType::F32, DType::F16) => "cast_f32_f16_strided", + (DType::F32, DType::I64) => "cast_f32_i64_strided", + (DType::F32, DType::U32) => "cast_f32_u32_strided", + (DType::F32, DType::U8) => "cast_f32_u8_strided", + + (DType::I64, DType::F32) => "cast_i64_f32_strided", + (DType::I64, DType::BF16) => "cast_i64_bf16_strided", + (DType::I64, DType::F16) => "cast_i64_f16_strided", + (DType::I64, DType::U32) => "cast_i64_u32_strided", + (DType::I64, DType::U8) => "cast_i64_u8_strided", + + (DType::U32, DType::BF16) => "cast_u32_bf16_strided", + (DType::U32, DType::F16) => "cast_u32_f16_strided", + (DType::U32, DType::F32) => "cast_u32_f32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I16) => "cast_u32_i16_strided", + (DType::U32, DType::I32) => "cast_u32_i32_strided", + (DType::U32, DType::I64) => "cast_u32_i64_strided", + + (DType::U8, DType::BF16) => "cast_u8_bf16_strided", + (DType::U8, DType::F16) => "cast_u8_f16_strided", + (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I16) => "cast_u8_i16_strided", + (DType::U8, DType::I32) => "cast_u8_i32_strided", + (DType::U8, DType::I64) => "cast_u8_i64_strided", + (DType::U8, DType::U32) => "cast_u8_u32_strided", + + (DType::I16, DType::F32) => "cast_i16_f32_strided", + (DType::I32, DType::F32) => "cast_i32_f32_strided", + + (left, right) => { + crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented") + } + }; + mcandle_metal_kernels::call_cast_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + src, + layout.stride(), + &buffer, + ) + .map_err(MetalError::from)?; + } + command_buffer.set_label("to_dtype"); + Ok(Self::new(buffer, device.clone(), el_count, dtype)) + } + + fn unary_impl(&self, layout: &Layout) -> Result { + let device = self.device(); + let dtype = self.dtype; + let shape = layout.shape(); + let el_count = shape.elem_count(); + let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; + let command_buffer = device.command_buffer()?; + command_buffer.set_label(B::KERNEL); + let src = buffer_o(&self.buffer, layout, self.dtype); + + match (el_count % 2, dtype, layout.is_contiguous()) { + (0, DType::BF16 | DType::F16, true) => { + use mcandle_metal_kernels::unary::contiguous_tiled; + let kernel_name = match (B::KERNEL, dtype) { + ("uabs", DType::F16) => contiguous_tiled::abs::HALF, + ("uabs", DType::F32) => contiguous_tiled::abs::FLOAT, + ("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT, + ("uceil", DType::F16) => contiguous_tiled::ceil::HALF, + ("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT, + ("uceil", 
DType::BF16) => contiguous_tiled::ceil::BFLOAT, + ("ucos", DType::F16) => contiguous_tiled::cos::HALF, + ("ucos", DType::F32) => contiguous_tiled::cos::FLOAT, + ("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT, + ("uerf", DType::F16) => contiguous_tiled::erf::HALF, + ("uerf", DType::F32) => contiguous_tiled::erf::FLOAT, + ("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT, + ("uexp", DType::F16) => contiguous_tiled::exp::HALF, + ("uexp", DType::F32) => contiguous_tiled::exp::FLOAT, + ("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT, + ("ufloor", DType::F16) => contiguous_tiled::floor::HALF, + ("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT, + ("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT, + ("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF, + ("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT, + ("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT, + ("ugelu", DType::F16) => contiguous_tiled::gelu::HALF, + ("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT, + ("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT, + ("ulog", DType::F16) => contiguous_tiled::log::HALF, + ("ulog", DType::F32) => contiguous_tiled::log::FLOAT, + ("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT, + ("uneg", DType::F16) => contiguous_tiled::neg::HALF, + ("uneg", DType::F32) => contiguous_tiled::neg::FLOAT, + ("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT, + ("urecip", DType::F16) => contiguous_tiled::recip::HALF, + ("urecip", DType::F32) => contiguous_tiled::recip::FLOAT, + ("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT, + ("urelu", DType::F16) => contiguous_tiled::relu::HALF, + ("urelu", DType::F32) => contiguous_tiled::relu::FLOAT, + ("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT, + ("uround", DType::F16) => contiguous_tiled::round::HALF, + ("uround", DType::F32) => contiguous_tiled::round::FLOAT, + ("uround", DType::BF16) => contiguous_tiled::round::BFLOAT, + ("usilu", DType::F16) => contiguous_tiled::silu::HALF, + ("usilu", DType::F32) => contiguous_tiled::silu::FLOAT, + ("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT, + ("usin", DType::F16) => contiguous_tiled::sin::HALF, + ("usin", DType::F32) => contiguous_tiled::sin::FLOAT, + ("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT, + ("usqr", DType::F16) => contiguous_tiled::sqr::HALF, + ("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT, + ("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT, + ("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF, + ("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT, + ("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT, + ("utanh", DType::F16) => contiguous_tiled::tanh::HALF, + ("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT, + ("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT, + ("usign", DType::F16) => contiguous_tiled::sign::HALF, + ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, + ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I16) => contiguous_tiled::sign::I16, + ("usign", DType::I32) => contiguous_tiled::sign::I32, + ("usign", DType::I64) => contiguous_tiled::sign::I64, + (name, dtype) => { + crate::bail!( + "Metal contiguous_tiled unary {name} {dtype:?} not implemented" + ) + } + }; + mcandle_metal_kernels::call_unary_contiguous_tiled( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } + (_, _, true) => { + use 
mcandle_metal_kernels::unary::contiguous; + let kernel_name = match (B::KERNEL, dtype) { + ("uabs", DType::F16) => contiguous::abs::HALF, + ("uabs", DType::F32) => contiguous::abs::FLOAT, + ("uabs", DType::BF16) => contiguous::abs::BFLOAT, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("uceil", DType::BF16) => contiguous::ceil::BFLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("ucos", DType::BF16) => contiguous::cos::BFLOAT, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uerf", DType::BF16) => contiguous::erf::BFLOAT, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + ("uexp", DType::BF16) => contiguous::exp::BFLOAT, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("ufloor", DType::BF16) => contiguous::floor::BFLOAT, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ulog", DType::BF16) => contiguous::log::BFLOAT, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uneg", DType::BF16) => contiguous::neg::BFLOAT, + ("urecip", DType::F16) => contiguous::recip::HALF, + ("urecip", DType::F32) => contiguous::recip::FLOAT, + ("urecip", DType::BF16) => contiguous::recip::BFLOAT, + ("urelu", DType::F16) => contiguous::relu::HALF, + ("urelu", DType::F32) => contiguous::relu::FLOAT, + ("urelu", DType::BF16) => contiguous::relu::BFLOAT, + ("uround", DType::F16) => contiguous::round::HALF, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("uround", DType::BF16) => contiguous::round::BFLOAT, + ("usilu", DType::F16) => contiguous::silu::HALF, + ("usilu", DType::F32) => contiguous::silu::FLOAT, + ("usilu", DType::BF16) => contiguous::silu::BFLOAT, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usin", DType::BF16) => contiguous::sin::BFLOAT, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqr", DType::BF16) => contiguous::sqr::BFLOAT, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT, + ("utanh", DType::F16) => contiguous::tanh::HALF, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, + ("usign", DType::F16) => contiguous::sign::HALF, + ("usign", DType::F32) => contiguous::sign::FLOAT, + ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I16) => contiguous::sign::I16, + ("usign", DType::I32) => contiguous::sign::I32, + ("usign", DType::I64) => contiguous::sign::I64, + (name, dtype) => { + crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") + } + }; + mcandle_metal_kernels::call_unary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } + (_, _, false) => { + use 
mcandle_metal_kernels::unary::strided; + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, + ("usilu", DType::F32) => strided::silu::FLOAT, + ("uabs", DType::F32) => strided::abs::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("urelu", DType::F32) => strided::relu::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("utanh", DType::F32) => strided::tanh::FLOAT, + + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("usilu", DType::F16) => strided::silu::HALF, + ("uabs", DType::F16) => strided::abs::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("urelu", DType::F16) => strided::relu::HALF, + ("uround", DType::F16) => strided::round::HALF, + ("utanh", DType::F16) => strided::tanh::HALF, + + ("ucos", DType::BF16) => strided::cos::BFLOAT, + ("usin", DType::BF16) => strided::sin::BFLOAT, + ("usqr", DType::BF16) => strided::sqr::BFLOAT, + ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, + ("uneg", DType::BF16) => strided::neg::BFLOAT, + ("uexp", DType::BF16) => strided::exp::BFLOAT, + ("ulog", DType::BF16) => strided::log::BFLOAT, + ("ugelu", DType::BF16) => strided::gelu::BFLOAT, + ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, + ("uerf", DType::BF16) => strided::erf::BFLOAT, + ("usilu", DType::BF16) => strided::silu::BFLOAT, + ("uabs", DType::BF16) => strided::abs::BFLOAT, + ("uceil", DType::BF16) => strided::ceil::BFLOAT, + ("ufloor", DType::BF16) => strided::floor::BFLOAT, + ("urelu", DType::BF16) => strided::relu::BFLOAT, + ("uround", DType::BF16) => strided::round::BFLOAT, + ("utanh", DType::BF16) => strided::tanh::BFLOAT, + + (name, dtype) => { + crate::bail!("Metal strided unary {name} {dtype:?} not implemented") + } + }; + let dst = BufferOffset::zero_offset(&buffer); + mcandle_metal_kernels::call_unary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + src, + layout.stride(), + dst, + ) + .map_err(MetalError::from)?; + } + } + + Ok(Self::new(buffer, device.clone(), el_count, dtype)) + } + + fn binary_impl( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + self.binary(B::KERNEL, rhs, lhs_l, rhs_l) + } + + fn where_cond( + &self, + layout: &Layout, + t: &Self, + t_l: &Layout, + f: &Self, + f_l: &Layout, + ) -> Result { + let device = self.device.clone(); + let shape = t_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let dtype = t.dtype; + let buffer = self.device.new_buffer(el, dtype, "where")?; + let command_buffer = self.device.command_buffer()?; + if t.dtype() != 
f.dtype() { + crate::bail!( + "Invalid where: different dtypes for values {:?} != {:?}", + t.dtype(), + f.dtype() + ); + } + let name = match (self.dtype, t.dtype()) { + (DType::U8, DType::F32) => "where_u8_f32", + (DType::U32, DType::F32) => "where_u32_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", + (DType::U8, DType::F16) => "where_u8_f16", + (DType::U8, DType::I16) => "where_u8_i16", + (DType::U8, DType::I32) => "where_u8_i32", + (DType::U8, DType::I64) => "where_u8_i64", + (DType::U8, DType::U32) => "where_u8_u32", + (DType::U8, DType::U8) => "where_u8_u8", + (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), + }; + let src = buffer_o(&self.buffer, layout, self.dtype); + let t = buffer_o(&t.buffer, t_l, t.dtype); + let f = buffer_o(&f.buffer, f_l, f.dtype); + mcandle_metal_kernels::call_where_cond_strided( + &device.device, + &command_buffer, + &device.kernels, + name, + dims, + src, + layout.stride(), + t, + t_l.stride(), + f, + f_l.stride(), + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device, el, dtype)) + } + + fn conv1d( + &self, + layout: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &ParamsConv1D, + ) -> Result { + let device = self.device().clone(); + let shape = layout.shape(); + let dims = shape.dims(); + let strides = layout.stride(); + + let stride = params.stride; + let dilation = params.dilation; + let padding = params.padding; + let k_size = params.k_size; + let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1; + let dst_el = dims[0] * l_out * dims[1] * k_size; + let dst = self + .device + .new_buffer(dst_el, self.dtype, "conv1d_im2col")?; + let command_buffer = self.device.command_buffer()?; + let name = match self.dtype { + DType::F32 => "im2col1d_f32", + dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"), + }; + let src = buffer_o(&self.buffer, layout, self.dtype); + mcandle_metal_kernels::call_im2col1d_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + layout.shape().dims(), + strides, + (k_size, stride, padding, dilation), + src, + &dst, + ) + .map_err(MetalError::from)?; + let col = Self { + buffer: dst, + device, + count: dst_el, + dtype: self.dtype, + }; + let l_out = params.l_out(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_size * params.c_in; + let m = l_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? 
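The im2col path above reduces conv1d to a batched matmul: the unfolded input has shape (b, l_out, c_in * k_size) and the kernel is viewed as (b, c_in * k_size, c_out). A minimal sketch of the output-length formula and the resulting shapes, with an illustrative helper that is not part of the patch:

```rust
/// Illustrative sketch (not part of the patch): the standard 1d convolution
/// output length used by the im2col path above.
fn conv1d_l_out(l_in: usize, k_size: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (l_in + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1
}

fn main() {
    // l_in = 10, k = 3, stride = 1, no padding, no dilation -> 8 output positions.
    assert_eq!(conv1d_l_out(10, 3, 1, 0, 1), 8);
    // Same input with stride 2 -> 4 output positions.
    assert_eq!(conv1d_l_out(10, 3, 2, 0, 1), 4);

    // The im2col buffer holds (b, l_out, c_in * k) elements and is multiplied by
    // the kernel viewed as (b, c_in * k, c_out) to produce (b, l_out, c_out).
    let (b, c_in, c_out, k) = (1usize, 4usize, 8usize, 3usize);
    let l_out = conv1d_l_out(10, k, 1, 0, 1);
    assert_eq!(b * l_out * c_in * k, 96); // im2col elements
    assert_eq!(b * l_out * c_out, 64); // output elements
}
```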
+ }; + let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) + } + + fn conv_transpose1d( + &self, + layout: &Layout, + k: &Self, + k_layout: &Layout, + params: &ParamsConvTranspose1D, + ) -> Result { + const USE_COL2IM_CONV1D_TR: bool = true; + + let can_use_col2im = k_layout.is_contiguous() + && params.dilation == 1 + && params.padding == 0 + && params.output_padding == 0; + let l_out = params.l_out(); + let dst_el = params.c_out * l_out * params.b_size; + + let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im { + let (b_size, c_in, l_in) = layout.shape().dims3()?; + let (c_in2, c_out, k_size) = k_layout.shape().dims3()?; + if c_in != c_in2 { + crate::bail!( + "convtr1d: shape mismatch on c_in {:?} {:?}", + layout.shape(), + k_layout.shape() + ) + } + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose1d")?; + + let name = match self.dtype { + DType::F32 => "col2im1d_f32", + DType::U32 => "col2im1d_u32", + DType::U8 => "col2im1d_u8", + dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"), + }; + let col = { + // This merges the last two dimensions of the kernel together. + let kernel_l_mm = Layout::new( + (b_size, c_in, k_size * c_out).into(), + vec![0, k_size * c_out, 1], + k_layout.start_offset(), + ); + self.matmul_with_alpha( + k, + None, + (b_size, l_in, c_out * k_size, c_in), + &layout.transpose(1, 2)?, + &kernel_l_mm, + )? + }; + // It is important for the command buffer to be obtained *after* the matmul + // kernel has run, otherwise we might use a command-buffer that has been commited + // already resulting in the following error. + // _status < MTLCommandBufferStatusCommitted > + // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:] + let command_buffer = self.device.command_buffer()?; + mcandle_metal_kernels::call_col2im1d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + &[b_size, l_in, c_out, k_size], + params.k_size, + params.stride, + BufferOffset::zero_offset(&col.buffer), + &buffer, + ) + .map_err(MetalError::from)?; + buffer + } else { + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose1d")?; + + let command_buffer = self.device.command_buffer()?; + let name = match self.dtype { + DType::F32 => "conv_transpose1d_f32", + DType::F16 => "conv_transpose1d_f16", + DType::BF16 => "conv_transpose1d_bf16", + DType::U32 => "conv_transpose1d_u32", + DType::U8 => "conv_transpose1d_u8", + dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_conv_transpose1d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + params.dilation, + params.stride, + params.padding, + params.output_padding, + params.c_out, + l_out, + params.b_size, + layout.dims(), + layout.stride(), + k_layout.dims(), + k_layout.stride(), + &self.buffer, + layout.start_offset() * self.dtype.size_in_bytes(), + &k.buffer, + k_layout.start_offset() * k.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + buffer + }; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) + } + + fn conv2d( + &self, + layout: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &ParamsConv2D, + ) -> Result { + let device = self.device().clone(); + let shape = layout.shape(); + let dims = shape.dims(); + + let stride = params.stride; + let dilation = params.dilation; + let padding = 
params.padding; + let h_k = params.k_h; + let w_k = params.k_w; + let h = dims[2]; + let w = dims[3]; + let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1; + let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1; + let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k; + + let dst = self + .device + .new_buffer(dst_el, self.dtype, "conv2d_im2col")?; + let command_buffer = self.device.command_buffer()?; + let name = match self.dtype { + DType::F32 => "im2col_f32", + DType::F16 => "im2col_f16", + DType::BF16 => "im2col_bf16", + DType::U8 => "im2col_u8", + DType::U32 => "im2col_u32", + dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"), + }; + let src = buffer_o(&self.buffer, layout, self.dtype); + mcandle_metal_kernels::call_im2col_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + layout.shape().dims(), + layout.stride(), + (h_k, w_k, stride, padding, dilation), + src, + &dst, + ) + .map_err(MetalError::from)?; + let col = Self { + buffer: dst, + device, + count: dst_el, + dtype: self.dtype, + }; + let h_out = params.out_h(); + let w_out = params.out_w(); + let b = params.b_size; + let n = params.c_out; + let k = params.k_h * params.k_w * params.c_in; + let m = h_out * w_out; + let col_l = Layout::contiguous((b, m, k)); + let res = if kernel_l.is_contiguous() { + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + } else { + // Make the kernel contiguous if not already the case. + let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; + let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) + .transpose(1, 2)? + .broadcast_as((b, k, n))?; + col.matmul_with_alpha(kernel, None, (b, m, n, k), &col_l, &kernel_l)? + }; + let res_l = Layout::contiguous((b, h_out, w_out, n)) + .transpose(1, 2)? 
+ .transpose(1, 3)?; + let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + res.copy_strided_src(&mut res_t, 0, &res_l)?; + Ok(res_t) + } + + fn conv_transpose2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &ParamsConvTranspose2D, + ) -> Result { + // Kernel shape: (c_in_k, c_out, h_k, w_k) + // Input shape: (b_size, c_in, h_in, w_in) + let (out_w, out_h) = (params.out_w(), params.out_h()); + let dst_el = params.c_out * out_w * out_h * params.b_size; + + let dims = l.dims(); + if dims.len() != 4 { + crate::bail!("unexpected input shape for conv_transpose2d {dims:?}, expected 4") + } + + let k_dims = kernel_l.dims(); + if k_dims.len() != 4 { + crate::bail!("unexpected kernel shape for conv_transpose2d {k_dims:?}, expected 4") + } + + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "conv_transpose2d")?; + + let command_buffer = self.device.command_buffer()?; + + let name = match self.dtype { + DType::F32 => "conv_transpose2d_f32", + DType::F16 => "conv_transpose2d_f16", + DType::BF16 => "conv_transpose2d_bf16", + dtype => crate::bail!("Metal conv_transpose2d {dtype:?} not implemented"), + }; + + mcandle_metal_kernels::call_conv_transpose2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + CallConvTranspose2dCfg { + dilation: params.dilation, + stride: params.stride, + padding: params.padding, + output_padding: params.output_padding, + c_out: params.c_out, + out_h, + out_w, + b_size: params.b_size, + input_dims: l.dims(), + input_stride: l.stride(), + kernel_dims: kernel_l.dims(), + kernel_stride: kernel_l.stride(), + input_offset: l.start_offset() * self.dtype.size_in_bytes(), + kernel_offset: kernel_l.start_offset() * kernel.dtype.size_in_bytes(), + }, + &self.buffer, + &kernel.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) + } + + fn avg_pool2d( + &self, + inp_l: &Layout, + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + ) -> Result { + let shape = inp_l.shape(); + let (b_size, channels, width, height) = shape.dims4()?; + let strides = inp_l.stride(); + let name = match self.dtype { + DType::F32 => "avg_pool2d_f32", + DType::F16 => "avg_pool2d_f16", + DType::BF16 => "avg_pool2d_bf16", + DType::U8 => "avg_pool2d_u8", + DType::U32 => "avg_pool2d_u32", + dtype => crate::bail!("Metal avg_pool2d {dtype:?} not implemented"), + }; + let out_w = (width - w_k) / w_stride + 1; + let out_h = (height - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; + let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?; + let command_buffers = self.device.command_buffer()?; + mcandle_metal_kernels::call_pool2d( + &self.device.device, + &command_buffers, + &self.device.kernels, + name, + inp_l.dims(), + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &self.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) + } + + fn max_pool2d( + &self, + inp_l: &Layout, + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + ) -> Result { + let shape = inp_l.shape(); + let (b_size, channels, width, height) = shape.dims4()?; + let strides = inp_l.stride(); + let name = match self.dtype { + DType::F32 => "max_pool2d_f32", + DType::F16 => "max_pool2d_f16", + DType::BF16 => "max_pool2d_bf16", + DType::U8 => "max_pool2d_u8", + DType::U32 => "max_pool2d_u32", + dtype => crate::bail!("Metal max_pool2d {dtype:?} not 
implemented"), + }; + let out_w = (width - w_k) / w_stride + 1; + let out_h = (height - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; + let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; + let command_buffers = self.device.command_buffer()?; + mcandle_metal_kernels::call_pool2d( + &self.device.device, + &command_buffers, + &self.device.kernels, + name, + inp_l.dims(), + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &self.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + crate::bail!("Metal upsample_nearest1d not implemented") + } + + fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result { + // let inp = &inp.slice(inp_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let strides = inp_l.stride(); + if dims.len() != 4 { + crate::bail!("unexpected input shape for upsample {dims:?}") + } + let name = match self.dtype { + DType::F32 => "upsample_nearest2d_f32", + DType::F16 => "upsample_nearest2d_f16", + DType::BF16 => "upsample_nearest2d_bf16", + DType::U8 => "upsample_nearest2d_u8", + DType::U32 => "upsample_nearest2d_u32", + dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"), + }; + + let dst_el = out_w * out_h * dims[0] * dims[1]; + let buffer = self + .device + .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, inp_l, self.dtype); + mcandle_metal_kernels::call_upsample_nearest_2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + dims, + strides, + out_w, + out_h, + src, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype)) + } + + fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { + if !ids_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "gather" }.bt()); + }; + let ids_el = ids_l.dims()[dim]; + let dst_el = ids_l.shape().elem_count(); + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype, "gather")?; + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "gather_u32_f32", + (DType::U32, DType::F16) => "gather_u32_f16", + (DType::U32, DType::BF16) => "gather_u32_bf16", + (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), + }; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, src_l, dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); + mcandle_metal_kernels::call_gather( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + src, + ids, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device.clone(), dst_el, dtype)) + } + + fn scatter_add( + &self, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, + ) -> Result { + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, l)?; + if !ids_l.is_contiguous() || !src_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt()); + }; + let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::F32) => "sa_u8_f32", + (DType::U8, DType::F16) => "sa_u8_f16", + (DType::U8, DType::BF16) => 
"sa_u8_bf16", + (DType::U32, DType::F32) => "sa_u32_f32", + (DType::U32, DType::F16) => "sa_u32_f16", + (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I16, DType::F32) => "sa_i16_f32", + (DType::I16, DType::F16) => "sa_i16_f16", + (DType::I16, DType::BF16) => "sa_i16_bf16", + (DType::I32, DType::F32) => "sa_i32_f32", + (DType::I32, DType::F16) => "sa_i32_f16", + (DType::I32, DType::BF16) => "sa_i32_bf16", + (DType::I64, DType::F32) => "sa_i64_f32", + (DType::I64, DType::F16) => "sa_i64_f16", + (DType::I64, DType::BF16) => "sa_i64_bf16", + _ => Err(MetalError::UnexpectedDType { + msg: "scatter-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); + mcandle_metal_kernels::call_scatter_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + dim, + src, + ids, + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) + } + + fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { + if !ids_l.is_contiguous() { + crate::bail!("Metal index_select requires contiguous ids") + } + let left_size: usize = src_l.dims()[..dim].iter().product(); + let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + let ids_el = ids_l.shape().elem_count(); + let dst_el = ids_el * left_size * right_size; + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::U8) => "is_u8_u8", + (DType::U8, DType::U32) => "is_u8_u32", + (DType::U8, DType::I64) => "is_u8_i64", + (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U8, DType::F32) => "is_u8_f32", + (DType::U8, DType::F16) => "is_u8_f16", + + (DType::U32, DType::U8) => "is_u32_u8", + (DType::U32, DType::U32) => "is_u32_u32", + (DType::U32, DType::I64) => "is_u32_i64", + (DType::U32, DType::F32) => "is_u32_f32", + (DType::U32, DType::F16) => "is_u32_f16", + (DType::U32, DType::BF16) => "is_u32_bf16", + + (DType::I16, DType::F32) => "is_i16_f32", + (DType::I16, DType::F16) => "is_i16_f16", + (DType::I16, DType::BF16) => "is_i16_bf16", + + (DType::I32, DType::F32) => "is_i32_f32", + (DType::I32, DType::F16) => "is_i32_f16", + (DType::I32, DType::BF16) => "is_i32_bf16", + + (DType::I64, DType::U8) => "is_i64_u8", + (DType::I64, DType::U32) => "is_i64_u32", + (DType::I64, DType::I64) => "is_i64_i64", + (DType::I64, DType::F32) => "is_i64_f32", + (DType::I64, DType::F16) => "is_i64_f16", + (DType::I64, DType::BF16) => "is_i64_bf16", + + (left, right) => { + crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") + } + }; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, src_l, dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); + mcandle_metal_kernels::call_index_select( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + src_l.is_contiguous(), + src_l.dims(), + src_l.stride(), + src, + ids, + &buffer, + ) + .map_err(MetalError::from)?; + Ok(Self::new(buffer, device.clone(), dst_el, dtype)) + } + + fn index_add( + &self, + l: &Layout, + ids: &Self, + ids_l: &Layout, + src: &Self, + src_l: &Layout, + dim: usize, + ) -> Result { + let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; + self.copy_strided_src(&mut acc, 0, 
l)?; + if !ids_l.is_contiguous() || !src_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); + }; + let name = match (ids.dtype, self.dtype) { + (DType::I16, DType::BF16) => "ia_i16_bf16", + (DType::I16, DType::F16) => "ia_i16_f16", + (DType::I16, DType::F32) => "ia_i16_f32", + (DType::I16, DType::I32) => "ia_i16_i32", + (DType::I16, DType::I64) => "ia_i16_i64", + (DType::I16, DType::U32) => "ia_i16_u32", + (DType::I16, DType::U8) => "ia_i16_u8", + + (DType::I32, DType::BF16) => "ia_i32_bf16", + (DType::I32, DType::F16) => "ia_i32_f16", + (DType::I32, DType::F32) => "ia_i32_f32", + (DType::I32, DType::I32) => "ia_i32_i32", + (DType::I32, DType::I64) => "ia_i32_i64", + (DType::I32, DType::U32) => "ia_i32_u32", + (DType::I32, DType::U8) => "ia_i32_u8", + + (DType::I64, DType::BF16) => "ia_i64_bf16", + (DType::I64, DType::F16) => "ia_i64_f16", + (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I16) => "ia_i64_i16", + (DType::I64, DType::I32) => "ia_i64_i32", + (DType::I64, DType::I64) => "ia_i64_i64", + (DType::I64, DType::U32) => "ia_i64_u32", + (DType::I64, DType::U8) => "ia_i64_u8", + + (DType::U32, DType::BF16) => "ia_u32_bf16", + (DType::U32, DType::F16) => "ia_u32_f16", + (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I16) => "ia_u32_i16", + (DType::U32, DType::I32) => "ia_u32_i32", + (DType::U32, DType::I64) => "ia_u32_i64", + (DType::U32, DType::U32) => "ia_u32_u32", + (DType::U32, DType::U8) => "ia_u32_u8", + + (DType::U8, DType::BF16) => "ia_u8_bf16", + (DType::U8, DType::F16) => "ia_u8_f16", + (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I16) => "ia_u8_i16", + (DType::U8, DType::I32) => "ia_u8_i32", + (DType::U8, DType::I64) => "ia_u8_i64", + (DType::U8, DType::U32) => "ia_u8_u32", + (DType::U8, DType::U8) => "ia_u8_u8", + + _ => Err(MetalError::UnexpectedDType { + msg: "index-add ids should be u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); + mcandle_metal_kernels::call_index_add( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + ids_l.dims(), + dim, + src, + ids, + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) + } + + fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + c_l: &Layout, + ) -> Result<()> { + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + DType::BF16 => "bgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } + }; + + let elem_count = b * m * n; + + match c_l.contiguous_offsets() { + Some((o1, o2)) => { + if o1 != 0 { + crate::bail!("`c` start offset must be 0"); + } + if o2 != elem_count { + crate::bail!("`c` end offset must be {}", elem_count) + } + } + None => crate::bail!("`c` has to be contiguous"), + }; + + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + mcandle_metal_kernels::call_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &c.buffer, + s.unwrap_or(1.) 
as f32, + 1., + ) + .map_err(MetalError::from)?; + Ok(()) + } + + fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("matmul"); + if self.dtype == DType::BF16 { + if s.unwrap_or(1.) != 1. { + return Err( + MetalError::Message(format!("mlx matmul doesn't support alpha {s:?}")).into(), + ); + } + mcandle_metal_kernels::call_mlx_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + mcandle_metal_kernels::GemmDType::BF16, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + } else if self.device.use_mlx_mm && s.unwrap_or(1.) == 1. { + let dtype = match self.dtype { + DType::F32 => mcandle_metal_kernels::GemmDType::F32, + DType::F16 => mcandle_metal_kernels::GemmDType::F16, + DType::BF16 => mcandle_metal_kernels::GemmDType::BF16, + dtype => { + return Err(MetalError::Message(format!( + "mlx matmul doesn't support {dtype:?}" + )) + .into()) + } + }; + mcandle_metal_kernels::call_mlx_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + dtype, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + dtype => { + return Err( + MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(), + ) + } + }; + + mcandle_metal_kernels::call_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + s.unwrap_or(1.) 
as f32, + 0., + ) + .map_err(MetalError::from)?; + } + Ok(Self::new( + buffer, + self.device.clone(), + b * m * n, + self.dtype(), + )) + } + + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + if self.dtype() != dst.dtype() { + crate::bail!( + "copy2d with inconsistent dtypes {:?} {:?}", + self.dtype(), + dst.dtype() + ) + } + let command_buffer = self.device.command_buffer()?; + if src_s == d2 && dst_s == d2 { + command_buffer.set_label("copy2d_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy2d_contiguous"); + let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger; + let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.end_encoding(); + } else { + let el_count = d1 * d2; + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => mcandle_metal_kernels::copy2d::FLOAT, + DType::F16 => mcandle_metal_kernels::copy2d::HALF, + DType::BF16 => mcandle_metal_kernels::copy2d::BFLOAT, + DType::I16 => mcandle_metal_kernels::copy2d::I16, + DType::I32 => mcandle_metal_kernels::copy2d::I32, + DType::I64 => mcandle_metal_kernels::copy2d::I64, + DType::U32 => mcandle_metal_kernels::copy2d::U32, + DType::U8 => mcandle_metal_kernels::copy2d::U8, + dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"), + }; + mcandle_metal_kernels::call_copy2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + &self.buffer, + &dst.buffer, + d1, + d2, + src_s, + dst_s, + src_o * self.dtype.size_in_bytes(), + dst_o * self.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy2d"); + } + Ok(()) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let command_buffer = self.device.command_buffer()?; + if src_l.is_contiguous() && self.dtype == dst.dtype() { + command_buffer.set_label("copy_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy_contiguous"); + let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; + let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.end_encoding(); + } else { + let src_shape = src_l.shape(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => mcandle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => mcandle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => mcandle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I16 => mcandle_metal_kernels::unary::strided::copy::I16, + DType::I32 => mcandle_metal_kernels::unary::strided::copy::I32, + DType::I64 => mcandle_metal_kernels::unary::strided::copy::I64, + DType::U32 => mcandle_metal_kernels::unary::strided::copy::U32, + DType::U8 => mcandle_metal_kernels::unary::strided::copy::U8, + dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"), + }; + let src = buffer_o(&self.buffer, src_l, self.dtype); + let dst = BufferOffset { + buffer: &dst.buffer, + offset_in_bytes: 
dst_offset * dst.dtype.size_in_bytes(), + }; + mcandle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + src, + src_l.stride(), + dst, + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy_strided"); + } + Ok(()) + } +} + +impl MetalStorage { + pub fn new(buffer: Arc, device: MetalDevice, count: usize, dtype: DType) -> Self { + Self { + buffer, + device, + count, + dtype, + } + } + + pub fn buffer(&self) -> &Buffer { + &self.buffer + } + + pub fn binary( + &self, + op: &'static str, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device(); + let shape = lhs_l.shape(); + let el_count = shape.elem_count(); + let command_buffer = device.command_buffer()?; + let lhs = buffer_o(&self.buffer, lhs_l, self.dtype); + let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype); + let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" { + use mcandle_metal_kernels::binary::contiguous; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype), + ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), + ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), + ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8), + ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8), + ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), + ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), + ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), + ("div", DType::F16) => (contiguous::div::HALF, self.dtype), + ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8), + ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8), + ("le", DType::F16) => (contiguous::le::HALF, DType::U8), + ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), + ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), + ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + + ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), + ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), + ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), + ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + + ("add", DType::I16) => (contiguous::add::I16, self.dtype), + ("sub", DType::I16) => (contiguous::sub::I16, self.dtype), + ("mul", DType::I16) => (contiguous::mul::I16, self.dtype), + ("div", DType::I16) => (contiguous::div::I16, self.dtype), + ("eq", DType::I16) => (contiguous::eq::I16, DType::U8), + ("ne", DType::I16) => (contiguous::ne::I16, DType::U8), + ("le", DType::I16) => (contiguous::le::I16, DType::U8), + ("lt", DType::I16) => (contiguous::lt::I16, DType::U8), + ("ge", DType::I16) => (contiguous::ge::I16, DType::U8), + ("gt", DType::I16) => (contiguous::gt::I16, DType::U8), + + 
("add", DType::I32) => (contiguous::add::I32, self.dtype), + ("sub", DType::I32) => (contiguous::sub::I32, self.dtype), + ("mul", DType::I32) => (contiguous::mul::I32, self.dtype), + ("div", DType::I32) => (contiguous::div::I32, self.dtype), + ("eq", DType::I32) => (contiguous::eq::I32, DType::U8), + ("ne", DType::I32) => (contiguous::ne::I32, DType::U8), + ("le", DType::I32) => (contiguous::le::I32, DType::U8), + ("lt", DType::I32) => (contiguous::lt::I32, DType::U8), + ("ge", DType::I32) => (contiguous::ge::I32, DType::U8), + ("gt", DType::I32) => (contiguous::gt::I32, DType::U8), + + ("add", DType::I64) => (contiguous::add::I64, self.dtype), + ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), + ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), + ("div", DType::I64) => (contiguous::div::I64, self.dtype), + ("eq", DType::I64) => (contiguous::eq::I64, DType::U8), + ("ne", DType::I64) => (contiguous::ne::I64, DType::U8), + ("le", DType::I64) => (contiguous::le::I64, DType::U8), + ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), + ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), + ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + + ("add", DType::U32) => (contiguous::add::U32, self.dtype), + ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), + ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), + ("div", DType::U32) => (contiguous::div::U32, self.dtype), + ("eq", DType::U32) => (contiguous::eq::U32, DType::U8), + ("ne", DType::U32) => (contiguous::ne::U32, DType::U8), + ("le", DType::U32) => (contiguous::le::U32, DType::U8), + ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), + ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), + ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + + ("add", DType::U8) => (contiguous::add::U8, self.dtype), + ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), + ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), + ("div", DType::U8) => (contiguous::div::U8, self.dtype), + ("eq", DType::U8) => (contiguous::eq::U8, DType::U8), + ("ne", DType::U8) => (contiguous::ne::U8, DType::U8), + ("le", DType::U8) => (contiguous::le::U8, DType::U8), + ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), + ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), + ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), + + (name, dtype) => { + crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") + } + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + mcandle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + lhs, + rhs, + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + } else { + use mcandle_metal_kernels::binary::strided; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("badd", DType::F32) => (strided::add::FLOAT, self.dtype), + ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype), + ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype), + ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype), + ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype), + ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype), + ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8), + ("le", DType::F32) => (strided::le::FLOAT, DType::U8), + ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), 
+ + ("badd", DType::F16) => (strided::add::HALF, self.dtype), + ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), + ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), + ("bdiv", DType::F16) => (strided::div::HALF, self.dtype), + ("bminimum", DType::F16) => (strided::min::HALF, self.dtype), + ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype), + ("eq", DType::F16) => (strided::eq::HALF, DType::U8), + ("ne", DType::F16) => (strided::ne::HALF, DType::U8), + ("le", DType::F16) => (strided::le::HALF, DType::U8), + ("lt", DType::F16) => (strided::lt::HALF, DType::U8), + ("ge", DType::F16) => (strided::ge::HALF, DType::U8), + ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + + ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype), + ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype), + ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype), + ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype), + ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype), + ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype), + ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + + ("badd", DType::I16) => (strided::add::I16, self.dtype), + ("bsub", DType::I16) => (strided::sub::I16, self.dtype), + ("bmul", DType::I16) => (strided::mul::I16, self.dtype), + ("bdiv", DType::I16) => (strided::div::I16, self.dtype), + ("bminimum", DType::I16) => (strided::min::I16, self.dtype), + ("bmaximum", DType::I16) => (strided::max::I16, self.dtype), + ("eq", DType::I16) => (strided::eq::I16, DType::U8), + ("ne", DType::I16) => (strided::ne::I16, DType::U8), + ("le", DType::I16) => (strided::le::I16, DType::U8), + ("lt", DType::I16) => (strided::lt::I16, DType::U8), + ("ge", DType::I16) => (strided::ge::I16, DType::U8), + ("gt", DType::I16) => (strided::gt::I16, DType::U8), + + ("badd", DType::I32) => (strided::add::I32, self.dtype), + ("bsub", DType::I32) => (strided::sub::I32, self.dtype), + ("bmul", DType::I32) => (strided::mul::I32, self.dtype), + ("bdiv", DType::I32) => (strided::div::I32, self.dtype), + ("bminimum", DType::I32) => (strided::min::I32, self.dtype), + ("bmaximum", DType::I32) => (strided::max::I32, self.dtype), + ("eq", DType::I32) => (strided::eq::I32, DType::U8), + ("ne", DType::I32) => (strided::ne::I32, DType::U8), + ("le", DType::I32) => (strided::le::I32, DType::U8), + ("lt", DType::I32) => (strided::lt::I32, DType::U8), + ("ge", DType::I32) => (strided::ge::I32, DType::U8), + ("gt", DType::I32) => (strided::gt::I32, DType::U8), + + ("badd", DType::I64) => (strided::add::I64, self.dtype), + ("bsub", DType::I64) => (strided::sub::I64, self.dtype), + ("bmul", DType::I64) => (strided::mul::I64, self.dtype), + ("bdiv", DType::I64) => (strided::div::I64, self.dtype), + ("bminimum", DType::I64) => (strided::min::I64, self.dtype), + ("bmaximum", DType::I64) => (strided::max::I64, self.dtype), + ("eq", DType::I64) => (strided::eq::I64, DType::U8), + ("ne", DType::I64) => (strided::ne::I64, DType::U8), + ("le", DType::I64) => (strided::le::I64, DType::U8), + ("lt", DType::I64) => (strided::lt::I64, DType::U8), + ("ge", DType::I64) => (strided::ge::I64, DType::U8), + ("gt", DType::I64) => (strided::gt::I64, DType::U8), + + ("badd", DType::U32) 
=> (strided::add::U32, self.dtype), + ("bsub", DType::U32) => (strided::sub::U32, self.dtype), + ("bmul", DType::U32) => (strided::mul::U32, self.dtype), + ("bdiv", DType::U32) => (strided::div::U32, self.dtype), + ("bminimum", DType::U32) => (strided::min::U32, self.dtype), + ("bmaximum", DType::U32) => (strided::max::U32, self.dtype), + ("eq", DType::U32) => (strided::eq::U32, DType::U8), + ("ne", DType::U32) => (strided::ne::U32, DType::U8), + ("le", DType::U32) => (strided::le::U32, DType::U8), + ("lt", DType::U32) => (strided::lt::U32, DType::U8), + ("ge", DType::U32) => (strided::ge::U32, DType::U8), + ("gt", DType::U32) => (strided::gt::U32, DType::U8), + + ("badd", DType::U8) => (strided::add::U8, self.dtype), + ("bsub", DType::U8) => (strided::sub::U8, self.dtype), + ("bmul", DType::U8) => (strided::mul::U8, self.dtype), + ("bdiv", DType::U8) => (strided::div::U8, self.dtype), + ("bminimum", DType::U8) => (strided::min::U8, self.dtype), + ("bmaximum", DType::U8) => (strided::max::U8, self.dtype), + ("eq", DType::U8) => (strided::eq::U8, DType::U8), + ("ne", DType::U8) => (strided::ne::U8, DType::U8), + ("le", DType::U8) => (strided::le::U8, DType::U8), + ("lt", DType::U8) => (strided::lt::U8, DType::U8), + ("ge", DType::U8) => (strided::ge::U8, DType::U8), + ("gt", DType::U8) => (strided::gt::U8, DType::U8), + + (name, dtype) => { + crate::bail!("Metal strided binary {name} {dtype:?} not implemented") + } + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + mcandle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + lhs, + lhs_l.stride(), + rhs, + rhs_l.stride(), + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + }; + command_buffer.set_label("binary"); + Ok(Self::new(buffer, device.clone(), el_count, dtype)) + } + + pub(crate) fn to_cpu(&self) -> Result> { + self.device.wait_until_completed()?; + + let ptr = self.buffer.contents() as *mut T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, self.count) }; + Ok(slice.to_vec()) + } +} + +impl BackendDevice for MetalDevice { + type Storage = MetalStorage; + + fn new(ordinal: usize) -> Result { + let device = metal::Device::all().swap_remove(ordinal); + let command_queue = device.new_command_queue(); + let kernels = Arc::new(Kernels::new()); + let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() { + Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true, + Ok(_) => false, + }; + let seed = Arc::new(Mutex::new(device.new_buffer_with_data( + [299792458].as_ptr() as *const c_void, + 4, + MTLResourceOptions::StorageModeShared, + ))); + let commands = device::Commands::new(command_queue)?; + Ok(Self { + id: DeviceId::new(), + device, + commands: Arc::new(RwLock::new(commands)), + buffers: Arc::new(RwLock::new(HashMap::new())), + kernels, + seed, + seed_value: Arc::new(RwLock::new(299792458)), + use_mlx_mm, + }) + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Metal { + gpu_id: self.registry_id() as usize, + } + } + + fn same_device(&self, rhs: &Self) -> bool { + self.id == rhs.id + } + + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?; + Ok(MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + let size = shape.elem_count() * dtype.size_in_bytes(); + let buffer 
= self.allocate_zeros(size)?; + Ok(MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + let name = match dtype { + DType::U8 => "fill_u8", + DType::U32 => "fill_u32", + DType::I64 => "fill_i64", + DType::F16 => "fill_f16", + DType::BF16 => "fill_bf16", + DType::F32 => "fill_f32", + DType::I32 => "fill_i32", + DType::I16 => "fill_i16", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), + DType::F64 => { + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + return self.storage_from_cpu_storage(&cpu_storage); + } + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?; + let command_buffer = self.command_buffer()?; + mcandle_metal_kernels::call_const_fill( + &self.device, + &command_buffer, + &self.kernels, + name, + shape.elem_count(), + &buffer, + 1., + ) + .map_err(MetalError::from)?; + + Ok(MetalStorage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) + } + + fn storage_from_slice(&self, s: &[T]) -> Result { + let (count, buffer) = match T::cpu_storage_ref(s) { + CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), + }; + Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE)) + } + + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { + let (count, buffer) = match storage { + CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), + }; + Ok(Self::Storage::new( + buffer?, + self.clone(), + count, + storage.dtype(), + )) + } + + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result { + self.storage_from_cpu_storage(&storage) + } + + fn rand_uniform( + &self, + shape: &Shape, + dtype: DType, + min: f64, + max: f64, + ) -> Result { + let name = match dtype { + DType::F32 => "rand_uniform_f32", + DType::F16 => "rand_uniform_f16", + DType::BF16 => "rand_uniform_bf16", + dtype => crate::bail!("rand_uniform not 
implemented for {dtype:?}"), + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?; + let command_buffer = self.command_buffer()?; + mcandle_metal_kernels::call_random_uniform( + &self.device, + &command_buffer, + &self.kernels, + name, + min as f32, + max as f32, + shape.elem_count(), + &self.seed.lock().unwrap(), + &buffer, + ) + .map_err(MetalError::from)?; + + Ok(Self::Storage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) + } + + fn rand_normal( + &self, + shape: &Shape, + dtype: DType, + mean: f64, + stddev: f64, + ) -> Result { + let name = match dtype { + DType::F32 => "rand_normal_f32", + DType::F16 => "rand_normal_f16", + DType::BF16 => "rand_normal_bf16", + dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"), + }; + let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?; + let command_buffer = self.command_buffer()?; + mcandle_metal_kernels::call_random_normal( + &self.device, + &command_buffer, + &self.kernels, + name, + mean as f32, + stddev as f32, + shape.elem_count(), + &self.seed.lock().unwrap(), + &buffer, + ) + .map_err(MetalError::from)?; + + Ok(Self::Storage::new( + buffer, + self.clone(), + shape.elem_count(), + dtype, + )) + } + + fn set_seed(&self, seed: u64) -> Result<()> { + let seed: u32 = seed.try_into().map_err(|_| { + MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string()) + })?; + + let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?; + let contents = seed_buffer.contents(); + unsafe { + std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1); + } + seed_buffer.did_modify_range(metal::NSRange::new(0, 4)); + + *self.seed_value.write().unwrap() = seed as u64; + + Ok(()) + } + + fn get_current_seed(&self) -> Result { + Ok(*self.seed_value.read().unwrap()) + } + + fn synchronize(&self) -> Result<()> { + self.wait_until_completed() + } +} diff --git a/mcandle/mcandle-core/src/mkl.rs b/mcandle/mcandle-core/src/mkl.rs new file mode 100644 index 000000000..359add741 --- /dev/null +++ b/mcandle/mcandle-core/src/mkl.rs @@ -0,0 +1,418 @@ +#![allow(dead_code)] +use libc::{c_char, c_double, c_float, c_int}; + +mod ffi { + use super::*; + extern "C" { + pub fn vsTanh(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdTanh(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsExp(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdExp(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsLn(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdLn(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsSin(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdSin(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsCos(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdCos(n: c_int, a: *const c_double, y: *mut c_double); + pub fn vsSqrt(n: c_int, a: *const c_float, y: *mut c_float); + pub fn vdSqrt(n: c_int, a: *const c_double, y: *mut c_double); + + pub fn vsAdd(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); + pub fn vdAdd(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); + pub fn vsSub(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); + pub fn vdSub(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); + pub fn vsMul(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); + pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); + pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut 
c_float); + pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); + pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); + pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); + pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float); + pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double); + + pub fn sgemm_( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_float, + a: *const c_float, + lda: *const c_int, + b: *const c_float, + ldb: *const c_int, + beta: *const c_float, + c: *mut c_float, + ldc: *const c_int, + ); + pub fn dgemm_( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_double, + a: *const c_double, + lda: *const c_int, + b: *const c_double, + ldb: *const c_int, + beta: *const c_double, + c: *mut c_double, + ldc: *const c_int, + ); + pub fn hgemm_( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const half::f16, + a: *const half::f16, + lda: *const c_int, + b: *const half::f16, + ldb: *const c_int, + beta: *const half::f16, + c: *mut half::f16, + ldc: *const c_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn sgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f32, + a: &[f32], + lda: i32, + b: &[f32], + ldb: i32, + beta: f32, + c: &mut [f32], + ldc: i32, +) { + ffi::sgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn dgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f64, + a: &[f64], + lda: i32, + b: &[f64], + ldb: i32, + beta: f64, + c: &mut [f64], + ldc: i32, +) { + ffi::dgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn hgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: half::f16, + a: &[half::f16], + lda: i32, + b: &[half::f16], + ldb: i32, + beta: half::f16, + c: &mut [half::f16], + ldc: i32, +) { + ffi::hgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[inline] +pub fn vs_exp(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_exp(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdExp(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_ln(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_ln(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y 
have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdLn(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_sin(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_sin(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdSin(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_cos(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_cos(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdCos(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_sqrt(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_sqrt(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdSqrt(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_sqr(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_sqr(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdMul(a_len as i32, a.as_ptr(), a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_tanh(a: &[f32], y: &mut [f32]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_tanh(a: &[f64], y: &mut [f64]) { + let a_len = a.len(); + let y_len = y.len(); + if a_len != y_len { + panic!("a and y have different lengths {a_len} <> {y_len}") + } + unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } +} + +// The vector functions from mkl can be performed in place by using the same array for input and +// output. 
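// Usage sketch (illustrative helper name): the explicit-output wrappers above
// panic on a length mismatch, while the in-place variants defined below pass the
// same slice as both input and output.
fn tanh_then_exp(xs: &[f32]) -> Vec<f32> {
    let mut ys = vec![0f32; xs.len()];
    // Out of place: `xs` is read, `ys` is written, lengths must match.
    vs_tanh(xs, &mut ys);
    // In place: `ys` is both the input and the output buffer.
    vs_exp_inplace(&mut ys);
    ys
}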
+// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html +#[inline] +pub fn vs_tanh_inplace(y: &mut [f32]) { + unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_tanh_inplace(y: &mut [f64]) { + unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_exp_inplace(y: &mut [f32]) { + unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_exp_inplace(y: &mut [f64]) { + unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vs_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +#[inline] +pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vd_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +#[inline] +pub fn vs_silu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = -v + } + vs_exp_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = v / (1.0 + *y) + } +} + +#[inline] +pub fn vd_silu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = -v + } + vd_exp_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = v / (1.0 + *y) + } +} + +macro_rules! binary_op { + ($fn_name:ident, $ty:ty, $mkl_name:ident) => { + #[inline] + pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) { + let a_len = a.len(); + let b_len = b.len(); + let y_len = y.len(); + if a_len != y_len || b_len != y_len { + panic!( + "{} a,b,y len mismatch {a_len} {b_len} {y_len}", + stringify!($fn_name) + ); + } + unsafe { ffi::$mkl_name(a_len as i32, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) } + } + }; +} +binary_op!(vs_add, f32, vsAdd); +binary_op!(vd_add, f64, vdAdd); +binary_op!(vs_sub, f32, vsSub); +binary_op!(vd_sub, f64, vdSub); +binary_op!(vs_mul, f32, vsMul); +binary_op!(vd_mul, f64, vdMul); +binary_op!(vs_div, f32, vsDiv); +binary_op!(vd_div, f64, vdDiv); +binary_op!(vs_max, f32, vsFmax); +binary_op!(vd_max, f64, vdFmax); +binary_op!(vs_min, f32, vsFmin); +binary_op!(vd_min, f64, vdFmin); diff --git a/mcandle/mcandle-core/src/npy.rs b/mcandle/mcandle-core/src/npy.rs new file mode 100644 index 000000000..28d5a63e9 --- /dev/null +++ b/mcandle/mcandle-core/src/npy.rs @@ -0,0 +1,472 @@ +//! Numpy support for tensors. +//! +//! The spec for the npy format can be found in +//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html). +//! The functions from this module can be used to read tensors from npy/npz files +//! or write tensors to these files. A npy file contains a single tensor (unnamed) +//! whereas a npz file can contain multiple named tensors. npz files are also compressed. +//! +//! These two formats are easy to use in Python using the numpy library. +//! +//! ```python +//! import numpy as np +//! x = np.arange(10) +//! +//! # Write a npy file. +//! np.save("test.npy", x) +//! +//! # Read a value from the npy file. +//! x = np.load("test.npy") +//! +//! # Write multiple values to a npz file. +//! values = { "x": x, "x_plus_one": x + 1 } +//! np.savez("test.npz", **values) +//! +//! 
# Load multiple values from a npz file.
+//! values = np.load("test.npz")
+//! ```
+use crate::{DType, Device, Error, Result, Shape, Tensor};
+use byteorder::{LittleEndian, ReadBytesExt};
+use float8::F8E4M3;
+use half::{bf16, f16, slice::HalfFloatSliceExt};
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::{BufReader, Read, Write};
+use std::path::Path;
+use std::slice;
+
+const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
+const NPY_SUFFIX: &str = ".npy";
+
+fn read_header<R: Read>(reader: &mut R) -> Result<String> {
+    let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
+    reader.read_exact(&mut magic_string)?;
+    if magic_string != NPY_MAGIC_STRING {
+        return Err(Error::Npy("magic string mismatch".to_string()));
+    }
+    let mut version = [0u8; 2];
+    reader.read_exact(&mut version)?;
+    let header_len_len = match version[0] {
+        1 => 2,
+        2 => 4,
+        otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
+    };
+    let mut header_len = vec![0u8; header_len_len];
+    reader.read_exact(&mut header_len)?;
+    let header_len = header_len
+        .iter()
+        .rev()
+        .fold(0_usize, |acc, &v| 256 * acc + v as usize);
+    let mut header = vec![0u8; header_len];
+    reader.read_exact(&mut header)?;
+    Ok(String::from_utf8_lossy(&header).to_string())
+}
+
+#[derive(Debug, PartialEq)]
+struct Header {
+    descr: DType,
+    fortran_order: bool,
+    shape: Vec<usize>,
+}
+
+impl Header {
+    fn shape(&self) -> Shape {
+        Shape::from(self.shape.as_slice())
+    }
+
+    fn to_string(&self) -> Result<String> {
+        let fortran_order = if self.fortran_order { "True" } else { "False" };
+        let mut shape = self
+            .shape
+            .iter()
+            .map(|x| x.to_string())
+            .collect::<Vec<String>>()
+            .join(",");
+        let descr = match self.descr {
+            DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
+            DType::F16 => "f2",
+            DType::F32 => "f4",
+            DType::F64 => "f8",
+            DType::I16 => "i2",
+            DType::I32 => "i4",
+            DType::I64 => "i8",
+            DType::U32 => "u4",
+            DType::U8 => "u1",
+            DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?,
+        };
+        if !shape.is_empty() {
+            shape.push(',')
+        }
+        Ok(format!(
+            "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
+        ))
+    }
+
+    // Hacky parser for the npy header, a typical example would be:
+    // {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
+    fn parse(header: &str) -> Result<Header>
{ + let header = + header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace()); + + let mut parts: Vec = vec![]; + let mut start_index = 0usize; + let mut cnt_parenthesis = 0i64; + for (index, c) in header.chars().enumerate() { + match c { + '(' => cnt_parenthesis += 1, + ')' => cnt_parenthesis -= 1, + ',' => { + if cnt_parenthesis == 0 { + parts.push(header[start_index..index].to_owned()); + start_index = index + 1; + } + } + _ => {} + } + } + parts.push(header[start_index..].to_owned()); + let mut part_map: HashMap = HashMap::new(); + for part in parts.iter() { + let part = part.trim(); + if !part.is_empty() { + match part.split(':').collect::>().as_slice() { + [key, value] => { + let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace()); + let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace()); + let _ = part_map.insert(key.to_owned(), value.to_owned()); + } + _ => return Err(Error::Npy(format!("unable to parse header {header}"))), + } + } + } + let fortran_order = match part_map.get("fortran_order") { + None => false, + Some(fortran_order) => match fortran_order.as_ref() { + "False" => false, + "True" => true, + _ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))), + }, + }; + let descr = match part_map.get("descr") { + None => return Err(Error::Npy("no descr in header".to_string())), + Some(descr) => { + if descr.is_empty() { + return Err(Error::Npy("empty descr".to_string())); + } + if descr.starts_with('>') { + return Err(Error::Npy(format!("little-endian descr {descr}"))); + } + // the only supported types in tensor are: + // float64, float32, float16, + // complex64, complex128, + // int64, int32, int16, int8, + // uint8, and bool. + match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') { + "e" | "f2" => DType::F16, + "f" | "f4" => DType::F32, + "d" | "f8" => DType::F64, + // "i" | "i4" => DType::S32, + "q" | "i8" => DType::I64, + // "h" | "i2" => DType::S16, + // "b" | "i1" => DType::S8, + "B" | "u1" => DType::U8, + "I" | "u4" => DType::U32, + "?" | "b1" => DType::U8, + // "F" | "F4" => DType::C64, + // "D" | "F8" => DType::C128, + descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))), + } + } + }; + let shape = match part_map.get("shape") { + None => return Err(Error::Npy("no shape in header".to_string())), + Some(shape) => { + let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ','); + if shape.is_empty() { + vec![] + } else { + shape + .split(',') + .map(|v| v.trim().parse::()) + .collect::, _>>()? + } + } + }; + Ok(Header { + descr, + fortran_order, + shape, + }) + } +} + +impl Tensor { + // TODO: Add the possibility to read directly to a device? 
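+    // `from_reader` decodes `shape.elem_count()` elements of `dtype` (little-endian) from
+    // `reader` into a CPU tensor, e.g. a (2, 3) F32 payload is 6 elements, i.e. 24 raw bytes.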
+ pub(crate) fn from_reader( + shape: Shape, + dtype: DType, + reader: &mut R, + ) -> Result { + let elem_count = shape.elem_count(); + match dtype { + DType::BF16 => { + let mut data_t = vec![bf16::ZERO; elem_count]; + reader.read_u16_into::(data_t.reinterpret_cast_mut())?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::F16 => { + let mut data_t = vec![f16::ZERO; elem_count]; + reader.read_u16_into::(data_t.reinterpret_cast_mut())?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::F32 => { + let mut data_t = vec![0f32; elem_count]; + reader.read_f32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::F64 => { + let mut data_t = vec![0f64; elem_count]; + reader.read_f64_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::U8 => { + let mut data_t = vec![0u8; elem_count]; + reader.read_exact(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::U32 => { + let mut data_t = vec![0u32; elem_count]; + reader.read_u32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::I16 => { + let mut data_t = vec![0i16; elem_count]; + reader.read_i16_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::I32 => { + let mut data_t = vec![0i32; elem_count]; + reader.read_i32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::I64 => { + let mut data_t = vec![0i64; elem_count]; + reader.read_i64_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + DType::F8E4M3 => { + let mut data_t = vec![F8E4M3::ZERO; elem_count]; + let ptr = data_t.as_mut_ptr().cast::(); + let len = data_t.len(); + reader.read_i8_into(unsafe { slice::from_raw_parts_mut(ptr, len) })?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } + } + } + + /// Reads a npy file and return the stored multi-dimensional array as a tensor. + pub fn read_npy>(path: T) -> Result { + let mut reader = File::open(path.as_ref())?; + let header = read_header(&mut reader)?; + let header = Header::parse(&header)?; + if header.fortran_order { + return Err(Error::Npy("fortran order not supported".to_string())); + } + Self::from_reader(header.shape(), header.descr, &mut reader) + } + + /// Reads a npz file and returns the stored multi-dimensional arrays together with their names. + pub fn read_npz>(path: T) -> Result> { + let zip_reader = BufReader::new(File::open(path.as_ref())?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut result = vec![]; + for i in 0..zip.len() { + let mut reader = zip.by_index(i)?; + let name = { + let name = reader.name(); + name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned() + }; + let header = read_header(&mut reader)?; + let header = Header::parse(&header)?; + if header.fortran_order { + return Err(Error::Npy("fortran order not supported".to_string())); + } + let s = Self::from_reader(header.shape(), header.descr, &mut reader)?; + result.push((name, s)) + } + Ok(result) + } + + /// Reads a npz file and returns the stored multi-dimensional arrays for some specified names. 
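+    ///
+    /// A minimal usage sketch (the archive and array names are the ones from the module
+    /// docs above and purely illustrative):
+    ///
+    /// ```ignore
+    /// let tensors = Tensor::read_npz_by_name("test.npz", &["x", "x_plus_one"])?;
+    /// assert_eq!(tensors.len(), 2);
+    /// ```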
+ pub fn read_npz_by_name>(path: T, names: &[&str]) -> Result> { + let zip_reader = BufReader::new(File::open(path.as_ref())?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut result = vec![]; + for name in names.iter() { + let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) { + Ok(reader) => reader, + Err(_) => Err(Error::Npy(format!( + "no array for {name} in {:?}", + path.as_ref() + )))?, + }; + let header = read_header(&mut reader)?; + let header = Header::parse(&header)?; + if header.fortran_order { + return Err(Error::Npy("fortran order not supported".to_string())); + } + let s = Self::from_reader(header.shape(), header.descr, &mut reader)?; + result.push(s) + } + Ok(result) + } + + fn write(&self, f: &mut T) -> Result<()> { + f.write_all(NPY_MAGIC_STRING)?; + f.write_all(&[1u8, 0u8])?; + let header = Header { + descr: self.dtype(), + fortran_order: false, + shape: self.dims().to_vec(), + }; + let mut header = header.to_string()?; + let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16; + for _ in 0..pad % 16 { + header.push(' ') + } + header.push('\n'); + f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; + f.write_all(header.as_bytes())?; + self.write_bytes(f) + } + + /// Writes a multi-dimensional array in the npy format. + pub fn write_npy>(&self, path: T) -> Result<()> { + let mut f = File::create(path.as_ref())?; + self.write(&mut f) + } + + /// Writes multiple multi-dimensional arrays using the npz format. + pub fn write_npz, T: AsRef, P: AsRef>( + ts: &[(S, T)], + path: P, + ) -> Result<()> { + let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?); + let options: zip::write::FileOptions<()> = + zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored); + + for (name, tensor) in ts.iter() { + zip.start_file(format!("{}.npy", name.as_ref()), options)?; + tensor.as_ref().write(&mut zip)? + } + Ok(()) + } +} + +/// Lazy tensor loader. +pub struct NpzTensors { + index_per_name: HashMap, + path: std::path::PathBuf, + // We do not store a zip reader as it needs mutable access to extract data. Instead we + // re-create a zip reader for each tensor. +} + +impl NpzTensors { + pub fn new>(path: T) -> Result { + let path = path.as_ref().to_owned(); + let zip_reader = BufReader::new(File::open(&path)?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut index_per_name = HashMap::new(); + for i in 0..zip.len() { + let file = zip.by_index(i)?; + let name = { + let name = file.name(); + name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned() + }; + index_per_name.insert(name, i); + } + Ok(Self { + index_per_name, + path, + }) + } + + pub fn names(&self) -> Vec<&String> { + self.index_per_name.keys().collect() + } + + /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids + /// reading the whole tensor data. 
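+    ///
+    /// For example (names purely illustrative), checking metadata without loading the data:
+    ///
+    /// ```ignore
+    /// let npz = NpzTensors::new("test.npz")?;
+    /// let (shape, dtype) = npz.get_shape_and_dtype("x")?;
+    /// ```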
+ pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> { + let index = match self.index_per_name.get(name) { + None => crate::bail!("cannot find tensor {name}"), + Some(index) => *index, + }; + let zip_reader = BufReader::new(File::open(&self.path)?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut reader = zip.by_index(index)?; + let header = read_header(&mut reader)?; + let header = Header::parse(&header)?; + Ok((header.shape(), header.descr)) + } + + pub fn get(&self, name: &str) -> Result> { + let index = match self.index_per_name.get(name) { + None => return Ok(None), + Some(index) => *index, + }; + // We hope that the file has not changed since first reading it. + let zip_reader = BufReader::new(File::open(&self.path)?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut reader = zip.by_index(index)?; + let header = read_header(&mut reader)?; + let header = Header::parse(&header)?; + if header.fortran_order { + return Err(Error::Npy("fortran order not supported".to_string())); + } + let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?; + Ok(Some(tensor)) + } +} + +#[cfg(test)] +mod tests { + use super::Header; + + #[test] + fn parse() { + let h = "{'descr': ' &'static str { + match self { + Self::ArgMax => "argmax", + Self::ArgMin => "argmin", + Self::Min => "min", + Self::Max => "max", + Self::Sum => "sum", + } + } +} + +// These ops return the same type as their input type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinaryOp { + Add, + Mul, + Sub, + Div, + Maximum, + Minimum, +} + +// Unary ops with no argument +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOp { + Exp, + Log, + Sin, + Cos, + Abs, + Neg, + Recip, + Sqr, + Sqrt, + Gelu, + GeluErf, + Erf, + Relu, + Silu, + Tanh, + Floor, + Ceil, + Round, + Sign, +} + +#[derive(Clone)] +pub enum Op { + Binary(Tensor, Tensor, BinaryOp), + Unary(Tensor, UnaryOp), + Cmp(Tensor, CmpOp), + // The third argument is the reduced shape with `keepdim=true`. + Reduce(Tensor, ReduceOp, Vec), + Matmul(Tensor, Tensor), + Gather(Tensor, Tensor, usize), + ScatterAdd(Tensor, Tensor, Tensor, usize), + IndexSelect(Tensor, Tensor, usize), + IndexAdd(Tensor, Tensor, Tensor, usize), + WhereCond(Tensor, Tensor, Tensor), + + #[allow(dead_code)] + Conv1D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + dilation: usize, + }, + + #[allow(dead_code)] + ConvTranspose1D { + arg: Tensor, + kernel: Tensor, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, + }, + + #[allow(dead_code)] + Conv2D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + dilation: usize, + }, + + #[allow(dead_code)] + ConvTranspose2D { + arg: Tensor, + kernel: Tensor, + padding: usize, + output_padding: usize, + stride: usize, + dilation: usize, + }, + + AvgPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + + MaxPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + + UpsampleNearest1D { + arg: Tensor, + target_size: usize, + }, + UpsampleNearest2D { + arg: Tensor, + target_h: usize, + target_w: usize, + }, + + Cat(Vec, usize), + + #[allow(dead_code)] // add is currently unused. 
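+    // Computes `arg * mul + add` elementwise.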
+ Affine { + arg: Tensor, + mul: f64, + add: f64, + }, + ToDType(Tensor), + Copy(Tensor), + Broadcast(Tensor), + Narrow(Tensor, usize, usize, usize), + SliceScatter0(Tensor, Tensor, usize), + Reshape(Tensor), + ToDevice(Tensor), + Transpose(Tensor, usize, usize), + Permute(Tensor, Vec), + Elu(Tensor, f64), + Powf(Tensor, f64), + CustomOp1( + Tensor, + std::sync::Arc>, + ), + CustomOp2( + Tensor, + Tensor, + std::sync::Arc>, + ), + CustomOp3( + Tensor, + Tensor, + Tensor, + std::sync::Arc>, + ), +} + +pub trait UnaryOpT { + const NAME: &'static str; + const KERNEL: &'static str; + const V: Self; + fn bf16(v1: bf16) -> bf16; + fn f16(v1: f16) -> f16; + fn f32(v1: f32) -> f32; + fn f64(v1: f64) -> f64; + fn f8e4m3(v1: F8E4M3) -> F8E4M3; + fn u8(v1: u8) -> u8; + fn u32(v1: u32) -> u32; + fn i16(v1: i16) -> i16; + fn i32(v1: i32) -> i32; + fn i64(v1: i64) -> i64; + + // There is no very good way to represent optional function in traits so we go for an explicit + // boolean flag to mark the function as existing. + const BF16_VEC: bool = false; + fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} + const F16_VEC: bool = false; + fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3]) {} + const F32_VEC: bool = false; + fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} + const F64_VEC: bool = false; + fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} +} + +pub trait BinaryOpT { + const NAME: &'static str; + const KERNEL: &'static str; + const V: Self; + fn bf16(v1: bf16, v2: bf16) -> bf16; + fn f16(v1: f16, v2: f16) -> f16; + fn f32(v1: f32, v2: f32) -> f32; + fn f64(v1: f64, v2: f64) -> f64; + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3; + fn u8(v1: u8, v2: u8) -> u8; + fn u32(v1: u32, v2: u32) -> u32; + fn i16(v1: i16, v2: i16) -> i16; + fn i32(v1: i32, v2: i32) -> i32; + fn i64(v1: i64, v2: i64) -> i64; + + const BF16_VEC: bool = false; + fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {} + const F16_VEC: bool = false; + fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {} + const F32_VEC: bool = false; + fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} + const F64_VEC: bool = false; + fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs1: &[F8E4M3], __xs2: &[F8E4M3], _ys: &mut [F8E4M3]) {} + const U8_VEC: bool = false; + fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} + const U32_VEC: bool = false; + fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {} + const I64_VEC: bool = false; + fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {} + const I32_VEC: bool = false; + fn i32_vec(_xs1: &[i32], _xs2: &[i32], _ys: &mut [i32]) {} + const I16_VEC: bool = false; + fn i16_vec(_xs1: &[i16], _xs2: &[i16], _ys: &mut [i16]) {} +} + +pub(crate) struct Add; +pub(crate) struct Div; +pub(crate) struct Mul; +pub(crate) struct Sub; +pub(crate) struct Maximum; +pub(crate) struct Minimum; +pub(crate) struct Exp; +pub(crate) struct Log; +pub(crate) struct Sin; +pub(crate) struct Cos; +pub(crate) struct Abs; +pub(crate) struct Neg; +pub(crate) struct Recip; +pub(crate) struct Sqr; +pub(crate) struct Sqrt; +pub(crate) struct Gelu; +pub(crate) struct GeluErf; +pub(crate) struct Erf; +pub(crate) struct Relu; +pub(crate) struct Silu; +pub(crate) struct Tanh; +pub(crate) struct Floor; +pub(crate) struct Ceil; +pub(crate) struct Round; +pub(crate) struct Sign; + +macro_rules! 
bin_op { + ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { + impl BinaryOpT for $op { + const NAME: &'static str = $name; + const KERNEL: &'static str = concat!("b", $name); + const V: Self = $op; + #[inline(always)] + fn bf16(v1: bf16, v2: bf16) -> bf16 { + $e(v1, v2) + } + #[inline(always)] + fn f16(v1: f16, v2: f16) -> f16 { + $e(v1, v2) + } + #[inline(always)] + fn f32(v1: f32, v2: f32) -> f32 { + $e(v1, v2) + } + #[inline(always)] + fn f64(v1: f64, v2: f64) -> f64 { + $e(v1, v2) + } + #[inline(always)] + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3 { + $e(v1, v2) + } + #[inline(always)] + fn u8(v1: u8, v2: u8) -> u8 { + $e(v1, v2) + } + #[inline(always)] + fn u32(v1: u32, v2: u32) -> u32 { + $e(v1, v2) + } + #[inline(always)] + fn i16(v1: i16, v2: i16) -> i16 { + $e(v1, v2) + } + #[inline(always)] + fn i32(v1: i32, v2: i32) -> i32 { + $e(v1, v2) + } + #[inline(always)] + fn i64(v1: i64, v2: i64) -> i64 { + $e(v1, v2) + } + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) { + crate::mkl::$f32_vec(xs1, xs2, ys) + } + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) { + crate::mkl::$f64_vec(xs1, xs2, ys) + } + + #[cfg(feature = "accelerate")] + const F32_VEC: bool = true; + #[cfg(feature = "accelerate")] + const F64_VEC: bool = true; + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) { + crate::accelerate::$f32_vec(xs1, xs2, ys) + } + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) { + crate::accelerate::$f64_vec(xs1, xs2, ys) + } + } + }; +} + +bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add); +bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub); +bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul); +bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div); +bin_op!( + Minimum, + "minimum", + |v1, v2| if v1 > v2 { v2 } else { v1 }, + vs_min, + vd_min +); +bin_op!( + Maximum, + "maximum", + |v1, v2| if v1 < v2 { v2 } else { v1 }, + vs_max, + vd_max +); + +#[allow(clippy::redundant_closure_call)] +macro_rules! 
unary_op { + ($op: ident, $name: literal, $a: ident, $e: expr) => { + impl UnaryOpT for $op { + const NAME: &'static str = $name; + const KERNEL: &'static str = concat!("u", $name); + const V: Self = $op; + #[inline(always)] + fn bf16($a: bf16) -> bf16 { + $e + } + #[inline(always)] + fn f16($a: f16) -> f16 { + $e + } + #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] + fn f32($a: f32) -> f32 { + $e + } + #[inline(always)] + fn f64($a: f64) -> f64 { + $e + } + #[inline(always)] + fn u8(_: u8) -> u8 { + todo!("no unary function for u8") + } + #[inline(always)] + fn u32(_: u32) -> u32 { + todo!("no unary function for u32") + } + #[inline(always)] + fn i64(_: i64) -> i64 { + todo!("no unary function for i64") + } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } + } + }; + + ($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => { + impl UnaryOpT for $op { + const NAME: &'static str = $name; + const KERNEL: &'static str = concat!("u", $name); + const V: Self = $op; + #[inline(always)] + fn bf16($a: bf16) -> bf16 { + $e + } + #[inline(always)] + fn f16($a: f16) -> f16 { + $e + } + #[inline(always)] + fn f32($a: f32) -> f32 { + $e + } + #[inline(always)] + fn f64($a: f64) -> f64 { + $e + } + #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] + fn u8(_: u8) -> u8 { + todo!("no unary function for u8") + } + #[inline(always)] + fn u32(_: u32) -> u32 { + todo!("no unary function for u32") + } + #[inline(always)] + fn i64(_: i64) -> i64 { + todo!("no unary function for i64") + } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } + #[inline(always)] + fn i16(_: i16) -> i16 { + todo!("no unary function for i16") + } + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::mkl::$f32_vec(xs, ys) + } + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::mkl::$f64_vec(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F32_VEC: bool = true; + #[cfg(feature = "accelerate")] + const F64_VEC: bool = true; + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::accelerate::$f32_vec(xs, ys) + } + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::accelerate::$f64_vec(xs, ys) + } + } + }; +} + +unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp); +unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln); +unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); +unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); +unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh); +unary_op!(Neg, "neg", v, -v); +unary_op!(Recip, "recip", v, v.recip()); +unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); +unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); + +// Hardcode the value for sqrt(2/pi) +// https://github.com/huggingface/mcandle/issues/1982 +#[allow(clippy::excessive_precision)] +const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373; +#[allow(clippy::excessive_precision)] +const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373; + +/// Tanh based approximation of the `gelu` operation +/// GeluErf is the more precise one. 
+/// +impl UnaryOpT for Gelu { + const NAME: &'static str = "gelu"; + const V: Self = Gelu; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f32_const(0.5) + * v + * (bf16::ONE + + bf16::tanh( + bf16::from_f32_const(SQRT_TWO_OVER_PI_F32) + * v + * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), + )) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f32_const(0.5) + * v + * (f16::ONE + + f16::tanh( + f16::from_f32_const(SQRT_TWO_OVER_PI_F32) + * v + * (f16::ONE + f16::from_f32_const(0.044715) * v * v), + )) + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f32(0.5) + * v + * (F8E4M3::ONE + + F8E4M3::tanh( + F8E4M3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (F8E4M3::ONE + F8E4M3::from_f32(0.044715) * v * v), + )) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v))) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + const KERNEL: &'static str = "ugelu"; + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::mkl::vs_gelu(xs, ys) + } + + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::mkl::vd_gelu(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F32_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::accelerate::vs_gelu(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F64_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::accelerate::vd_gelu(xs, ys) + } +} + +/// `erf` operation +/// +impl UnaryOpT for Erf { + const NAME: &'static str = "erf"; + const KERNEL: &'static str = "uerf"; + const V: Self = Erf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + crate::cpu::erf::erf(v) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } +} + +/// Silu operation +impl UnaryOpT for Silu { + const NAME: &'static str = "silu"; + const V: Self = Silu; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v / (bf16::ONE + (-v).exp()) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v / (f16::ONE + (-v).exp()) + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v / (F8E4M3::ONE + (-v).exp()) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v / (1.0 + (-v).exp()) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v / (1.0 + (-v).exp()) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } 
+ #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } + const KERNEL: &'static str = "usilu"; + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::mkl::vs_silu(xs, ys) + } + + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::mkl::vd_silu(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F32_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::accelerate::vs_silu(xs, ys) + } + + #[cfg(feature = "accelerate")] + const F64_VEC: bool = true; + + #[cfg(feature = "accelerate")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::accelerate::vd_silu(xs, ys) + } +} + +impl UnaryOpT for Abs { + const NAME: &'static str = "abs"; + const KERNEL: &'static str = "uabs"; + const V: Self = Abs; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.abs() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.abs() + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.abs() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.abs() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.abs() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v.abs() + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.abs() + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v.abs() + } +} + +impl UnaryOpT for Ceil { + const NAME: &'static str = "ceil"; + const KERNEL: &'static str = "uceil"; + const V: Self = Ceil; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.ceil() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.ceil() + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.ceil() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.ceil() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.ceil() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } +} + +impl UnaryOpT for Floor { + const NAME: &'static str = "floor"; + const KERNEL: &'static str = "ufloor"; + const V: Self = Floor; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.floor() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.floor() + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.floor() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.floor() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.floor() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } +} + +impl UnaryOpT for Round { + const NAME: &'static str = "round"; + const KERNEL: &'static str = "uround"; + const V: Self = Round; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.round() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.round() + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.round() + } + #[inline(always)] + fn 
f32(v: f32) -> f32 { + v.round() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.round() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } +} + +impl UnaryOpT for GeluErf { + const NAME: &'static str = "gelu_erf"; + const KERNEL: &'static str = "ugelu_erf"; + const V: Self = GeluErf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + (crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } + #[inline(always)] + fn i16(_: i16) -> i16 { + 0 + } +} + +impl UnaryOpT for Relu { + const NAME: &'static str = "relu"; + const KERNEL: &'static str = "urelu"; + const V: Self = Relu; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.max(bf16::ZERO) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.max(f16::ZERO) + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.max(F8E4M3::ZERO) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.max(0f32) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.max(0f64) + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } + #[inline(always)] + fn i16(v: i16) -> i16 { + v + } +} + +/// `BackpropOp` is a wrapper around `Option`. 
The main goal is to ensure that dependencies are +/// properly checked when creating a new value +#[derive(Clone)] +pub struct BackpropOp(Option); + +impl BackpropOp { + pub(crate) fn none() -> Self { + BackpropOp(None) + } + + pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self { + let op = if arg.track_op() { + Some(f(arg.clone())) + } else { + None + }; + Self(op) + } + + pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self { + let op = if arg1.track_op() || arg2.track_op() { + Some(f(arg1.clone(), arg2.clone())) + } else { + None + }; + Self(op) + } + + pub(crate) fn new3( + arg1: &Tensor, + arg2: &Tensor, + arg3: &Tensor, + f: impl Fn(Tensor, Tensor, Tensor) -> Op, + ) -> Self { + let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() { + Some(f(arg1.clone(), arg2.clone(), arg3.clone())) + } else { + None + }; + Self(op) + } + + pub(crate) fn new>(args: &[A], f: impl Fn(Vec) -> Op) -> Self { + let op = if args.iter().any(|arg| arg.as_ref().track_op()) { + let args: Vec = args.iter().map(|arg| arg.as_ref().clone()).collect(); + Some(f(args)) + } else { + None + }; + Self(op) + } + + pub(crate) fn is_none(&self) -> bool { + self.0.is_none() + } +} + +impl std::ops::Deref for BackpropOp { + type Target = Option; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl UnaryOpT for Sign { + const NAME: &'static str = "sign"; + const KERNEL: &'static str = "usign"; + const V: Self = Sign; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) + } + #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from((v > F8E4M3::ZERO) as i8 as f32) + - F8E4M3::from((v < F8E4M3::ZERO) as i8 as f32) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + f32::from(v > 0.) - f32::from(v < 0.) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + f64::from(v > 0.) - f64::from(v < 0.) + } + #[inline(always)] + fn u8(v: u8) -> u8 { + u8::min(1, v) + } + #[inline(always)] + fn u32(v: u32) -> u32 { + u32::min(1, v) + } + #[inline(always)] + fn i64(v: i64) -> i64 { + (v > 0) as i64 - (v < 0) as i64 + } + #[inline(always)] + fn i32(v: i32) -> i32 { + (v > 0) as i32 - (v < 0) as i32 + } + #[inline(always)] + fn i16(v: i16) -> i16 { + (v > 0) as i16 - (v < 0) as i16 + } +} diff --git a/mcandle/mcandle-core/src/pickle.rs b/mcandle/mcandle-core/src/pickle.rs new file mode 100644 index 000000000..08335257c --- /dev/null +++ b/mcandle/mcandle-core/src/pickle.rs @@ -0,0 +1,817 @@ +// Just enough pickle support to be able to read PyTorch checkpoints. +// This hardcodes objects that are required for tensor reading, we may want to make this a bit more +// composable/tensor agnostic at some point. 
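+//
+// As a rough usage sketch (checkpoint path and tensor name hypothetical), the entry points
+// defined further down are driven as:
+//
+//     let tensors = read_all("model.pth")?;          // eager: Vec<(String, Tensor)>
+//     let pth = PthTensors::new("model.pth", None)?; // lazy: re-opens the zip per tensor
+//     let w = pth.get("weight")?;                    // Option<Tensor>
+//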
+use crate::{DType, Error as E, Layout, Result, Tensor}; +use byteorder::{LittleEndian, ReadBytesExt}; +use std::collections::HashMap; +use std::io::BufRead; + +const VERBOSE: bool = false; + +// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/ +#[repr(u8)] +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum OpCode { + // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123 + Proto = 0x80, + Global = b'c', + BinPut = b'q', + LongBinPut = b'r', + EmptyTuple = b')', + Reduce = b'R', + Mark = b'(', + BinUnicode = b'X', + BinInt = b'J', + Tuple = b't', + BinPersId = b'Q', + BinInt1 = b'K', + BinInt2 = b'M', + Tuple1 = 0x85, + Tuple2 = 0x86, + Tuple3 = 0x87, + NewTrue = 0x88, + NewFalse = 0x89, + None = b'N', + BinGet = b'h', + LongBinGet = b'j', + SetItem = b's', + SetItems = b'u', + EmptyDict = b'}', + Dict = b'd', + Build = b'b', + Stop = b'.', + NewObj = 0x81, + EmptyList = b']', + BinFloat = b'G', + Append = b'a', + Appends = b'e', +} + +// Avoid using FromPrimitive so as not to drag another dependency. +impl TryFrom for OpCode { + type Error = u8; + fn try_from(value: u8) -> std::result::Result { + match value { + 0x80 => Ok(Self::Proto), + b'c' => Ok(Self::Global), + b'q' => Ok(Self::BinPut), + b'r' => Ok(Self::LongBinPut), + b')' => Ok(Self::EmptyTuple), + b'R' => Ok(Self::Reduce), + b'(' => Ok(Self::Mark), + b'X' => Ok(Self::BinUnicode), + b'J' => Ok(Self::BinInt), + b't' => Ok(Self::Tuple), + b'Q' => Ok(Self::BinPersId), + b'K' => Ok(Self::BinInt1), + b'M' => Ok(Self::BinInt2), + b'N' => Ok(Self::None), + 0x85 => Ok(Self::Tuple1), + 0x86 => Ok(Self::Tuple2), + 0x87 => Ok(Self::Tuple3), + 0x88 => Ok(Self::NewTrue), + 0x89 => Ok(Self::NewFalse), + b'h' => Ok(Self::BinGet), + b'j' => Ok(Self::LongBinGet), + b's' => Ok(Self::SetItem), + b'u' => Ok(Self::SetItems), + b'}' => Ok(Self::EmptyDict), + b'd' => Ok(Self::EmptyDict), + b'b' => Ok(Self::Build), + b'.' 
=> Ok(Self::Stop), + 0x81 => Ok(Self::NewObj), + b']' => Ok(Self::EmptyList), + b'G' => Ok(Self::BinFloat), + b'a' => Ok(Self::Append), + b'e' => Ok(Self::Appends), + value => Err(value), + } + } +} + +fn read_to_newline(r: &mut R) -> Result> { + let mut data: Vec = Vec::with_capacity(32); + r.read_until(b'\n', &mut data)?; + data.pop(); + if data.last() == Some(&b'\r') { + data.pop(); + } + Ok(data) +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Object { + Class { + module_name: String, + class_name: String, + }, + Int(i32), + Float(f64), + Unicode(String), + Bool(bool), + None, + Tuple(Vec), + List(Vec), + Mark, + Dict(Vec<(Object, Object)>), + Reduce { + callable: Box, + args: Box, + }, + Build { + callable: Box, + args: Box, + }, + PersistentLoad(Box), +} + +type OResult = std::result::Result; + +impl Object { + pub fn unicode(self) -> OResult { + match self { + Self::Unicode(t) => Ok(t), + _ => Err(self), + } + } + + pub fn reduce(self) -> OResult<(Self, Self)> { + match self { + Self::Reduce { callable, args } => Ok((*callable, *args)), + _ => Err(self), + } + } + + pub fn none(self) -> OResult<()> { + match self { + Self::None => Ok(()), + _ => Err(self), + } + } + + pub fn persistent_load(self) -> OResult { + match self { + Self::PersistentLoad(t) => Ok(*t), + _ => Err(self), + } + } + + pub fn bool(self) -> OResult { + match self { + Self::Bool(t) => Ok(t), + _ => Err(self), + } + } + + pub fn int(self) -> OResult { + match self { + Self::Int(t) => Ok(t), + _ => Err(self), + } + } + + pub fn tuple(self) -> OResult> { + match self { + Self::Tuple(t) => Ok(t), + _ => Err(self), + } + } + + pub fn dict(self) -> OResult> { + match self { + Self::Dict(t) => Ok(t), + _ => Err(self), + } + } + + pub fn class(self) -> OResult<(String, String)> { + match self { + Self::Class { + module_name, + class_name, + } => Ok((module_name, class_name)), + _ => Err(self), + } + } + + pub fn into_tensor_info( + self, + name: Self, + dir_name: &std::path::Path, + ) -> Result> { + let name = match name.unicode() { + Ok(name) => name, + Err(_) => return Ok(None), + }; + let (callable, args) = match self.reduce() { + Ok(callable_args) => callable_args, + _ => return Ok(None), + }; + let (callable, args) = match callable { + Object::Class { + module_name, + class_name, + } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => { + let mut args = args.tuple()?; + let callable = args.remove(0); + let args = args.remove(1); + (callable, args) + } + Object::Class { + module_name, + class_name, + } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => { + let mut args = args.tuple()?; + args.remove(0).reduce()? 
+ } + _ => (callable, args), + }; + match callable { + Object::Class { + module_name, + class_name, + } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {} + _ => return Ok(None), + }; + let (layout, dtype, file_path, storage_size) = rebuild_args(args)?; + Ok(Some(TensorInfo { + name, + dtype, + layout, + path: format!("{}/{}", dir_name.to_string_lossy(), file_path), + storage_size, + })) + } +} + +impl TryFrom for String { + type Error = Object; + fn try_from(value: Object) -> std::result::Result { + match value { + Object::Unicode(s) => Ok(s), + other => Err(other), + } + } +} + +impl TryFrom for usize { + type Error = Object; + fn try_from(value: Object) -> std::result::Result { + match value { + Object::Int(s) if s >= 0 => Ok(s as usize), + other => Err(other), + } + } +} + +impl> TryFrom for Vec { + type Error = Object; + fn try_from(value: Object) -> std::result::Result { + match value { + Object::Tuple(values) => { + // This does not return the appropriate value in the error case but instead return + // the object related to the first error. + values + .into_iter() + .map(|v| T::try_from(v)) + .collect::, Self::Error>>() + } + other => Err(other), + } + } +} + +#[derive(Debug)] +pub struct Stack { + stack: Vec, + memo: HashMap, +} + +impl Stack { + pub fn empty() -> Self { + Self { + stack: Vec::with_capacity(512), + memo: HashMap::new(), + } + } + + pub fn stack(&self) -> &[Object] { + self.stack.as_slice() + } + + pub fn read_loop(&mut self, r: &mut R) -> Result<()> { + loop { + if self.read(r)? { + break; + } + } + Ok(()) + } + + pub fn finalize(mut self) -> Result { + self.pop() + } + + fn push(&mut self, obj: Object) { + self.stack.push(obj) + } + + fn pop(&mut self) -> Result { + match self.stack.pop() { + None => crate::bail!("unexpected empty stack"), + Some(obj) => Ok(obj), + } + } + + // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD + fn build(&mut self) -> Result<()> { + let args = self.pop()?; + let obj = self.pop()?; + let obj = match (obj, args) { + (Object::Dict(mut obj), Object::Dict(mut args)) => { + obj.append(&mut args); + Object::Dict(obj) + } + (obj, args) => Object::Build { + callable: Box::new(obj), + args: Box::new(args), + }, + }; + self.push(obj); + Ok(()) + } + + fn reduce(&mut self) -> Result<()> { + let args = self.pop()?; + let callable = self.pop()?; + #[allow(clippy::single_match)] + let reduced = match &callable { + Object::Class { + module_name, + class_name, + } => { + if module_name == "collections" + && (class_name == "OrderedDict" || class_name == "defaultdict") + { + // TODO: have a separate ordered dict and a separate default dict. + Some(Object::Dict(vec![])) + } else { + None + } + } + _ => None, + }; + let reduced = reduced.unwrap_or_else(|| Object::Reduce { + callable: Box::new(callable), + args: Box::new(args), + }); + self.push(reduced); + Ok(()) + } + + fn last(&mut self) -> Result<&mut Object> { + match self.stack.last_mut() { + None => crate::bail!("unexpected empty stack"), + Some(obj) => Ok(obj), + } + } + + fn memo_get(&self, id: u32) -> Result { + match self.memo.get(&id) { + None => crate::bail!("missing object in memo {id}"), + Some(obj) => { + // Maybe we should use refcounting rather than doing potential large clones here. 
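+                // (e.g. a memo of `Rc<Object>` returning `Rc::clone` would avoid deep-copying
+                // large nested dicts; for now a plain clone is returned below.)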
+ Ok(obj.clone()) + } + } + } + + fn memo_put(&mut self, id: u32) -> Result<()> { + let obj = self.last()?.clone(); + self.memo.insert(id, obj); + Ok(()) + } + + fn persistent_load(&self, id: Object) -> Result { + Ok(Object::PersistentLoad(Box::new(id))) + } + + fn new_obj(&self, class: Object, args: Object) -> Result { + Ok(Object::Reduce { + callable: Box::new(class), + args: Box::new(args), + }) + } + + fn pop_to_marker(&mut self) -> Result> { + let mut mark_idx = None; + for (idx, obj) in self.stack.iter().enumerate().rev() { + if obj == &Object::Mark { + mark_idx = Some(idx); + break; + } + } + match mark_idx { + Some(mark_idx) => { + let objs = self.stack.split_off(mark_idx + 1); + self.stack.pop(); + Ok(objs) + } + None => { + crate::bail!("marker object not found") + } + } + } + + pub fn read(&mut self, r: &mut R) -> Result { + let op_code = match OpCode::try_from(r.read_u8()?) { + Ok(op_code) => op_code, + Err(op_code) => { + crate::bail!("unknown op-code {op_code}") + } + }; + // println!("op: {op_code:?}"); + // println!("{:?}", self.stack); + match op_code { + OpCode::Proto => { + let version = r.read_u8()?; + if VERBOSE { + println!("proto {version}"); + } + } + OpCode::Global => { + let module_name = read_to_newline(r)?; + let class_name = read_to_newline(r)?; + let module_name = String::from_utf8_lossy(&module_name).to_string(); + let class_name = String::from_utf8_lossy(&class_name).to_string(); + self.push(Object::Class { + module_name, + class_name, + }) + } + OpCode::BinInt1 => { + let arg = r.read_u8()?; + self.push(Object::Int(arg as i32)) + } + OpCode::BinInt2 => { + let arg = r.read_u16::()?; + self.push(Object::Int(arg as i32)) + } + OpCode::BinInt => { + let arg = r.read_i32::()?; + self.push(Object::Int(arg)) + } + OpCode::BinFloat => { + // Somehow floats are encoded using BigEndian whereas int types use LittleEndian. 
+ // https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855 + // https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243 + let arg = r.read_f64::()?; + self.push(Object::Float(arg)) + } + OpCode::BinUnicode => { + let len = r.read_u32::()?; + let mut data = vec![0u8; len as usize]; + r.read_exact(&mut data)?; + let data = String::from_utf8(data).map_err(E::wrap)?; + self.push(Object::Unicode(data)) + } + OpCode::BinPersId => { + let id = self.pop()?; + let obj = self.persistent_load(id)?; + self.push(obj) + } + OpCode::Tuple => { + let objs = self.pop_to_marker()?; + self.push(Object::Tuple(objs)) + } + OpCode::Tuple1 => { + let obj = self.pop()?; + self.push(Object::Tuple(vec![obj])) + } + OpCode::Tuple2 => { + let obj2 = self.pop()?; + let obj1 = self.pop()?; + self.push(Object::Tuple(vec![obj1, obj2])) + } + OpCode::Tuple3 => { + let obj3 = self.pop()?; + let obj2 = self.pop()?; + let obj1 = self.pop()?; + self.push(Object::Tuple(vec![obj1, obj2, obj3])) + } + OpCode::NewTrue => self.push(Object::Bool(true)), + OpCode::NewFalse => self.push(Object::Bool(false)), + OpCode::Append => { + let value = self.pop()?; + let pylist = self.last()?; + if let Object::List(d) = pylist { + d.push(value) + } else { + crate::bail!("expected a list, got {pylist:?}") + } + } + OpCode::Appends => { + let objs = self.pop_to_marker()?; + let pylist = self.last()?; + if let Object::List(d) = pylist { + d.extend(objs) + } else { + crate::bail!("expected a list, got {pylist:?}") + } + } + OpCode::SetItem => { + let value = self.pop()?; + let key = self.pop()?; + let pydict = self.last()?; + if let Object::Dict(d) = pydict { + d.push((key, value)) + } else { + crate::bail!("expected a dict, got {pydict:?}") + } + } + OpCode::SetItems => { + let mut objs = self.pop_to_marker()?; + let pydict = self.last()?; + if let Object::Dict(d) = pydict { + if objs.len() % 2 != 0 { + crate::bail!("setitems: not an even number of objects") + } + while let Some(value) = objs.pop() { + let key = objs.pop().unwrap(); + d.push((key, value)) + } + } else { + crate::bail!("expected a dict, got {pydict:?}") + } + } + OpCode::None => self.push(Object::None), + OpCode::Stop => { + return Ok(true); + } + OpCode::Build => self.build()?, + OpCode::EmptyDict => self.push(Object::Dict(vec![])), + OpCode::Dict => { + let mut objs = self.pop_to_marker()?; + let mut pydict = vec![]; + if objs.len() % 2 != 0 { + crate::bail!("setitems: not an even number of objects") + } + while let Some(value) = objs.pop() { + let key = objs.pop().unwrap(); + pydict.push((key, value)) + } + self.push(Object::Dict(pydict)) + } + OpCode::Mark => self.push(Object::Mark), + OpCode::Reduce => self.reduce()?, + OpCode::EmptyTuple => self.push(Object::Tuple(vec![])), + OpCode::EmptyList => self.push(Object::List(vec![])), + OpCode::BinGet => { + let arg = r.read_u8()?; + let obj = self.memo_get(arg as u32)?; + self.push(obj) + } + OpCode::LongBinGet => { + let arg = r.read_u32::()?; + let obj = self.memo_get(arg)?; + self.push(obj) + } + OpCode::BinPut => { + let arg = r.read_u8()?; + self.memo_put(arg as u32)? + } + OpCode::LongBinPut => { + let arg = r.read_u32::()?; + self.memo_put(arg)? 
+ } + OpCode::NewObj => { + let args = self.pop()?; + let class = self.pop()?; + let obj = self.new_obj(class, args)?; + self.push(obj) + } + } + Ok(false) + } +} + +impl From for E { + fn from(value: Object) -> Self { + E::Msg(format!("conversion error on {value:?}")) + } +} + +// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198 +// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks +fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> { + let mut args = args.tuple()?; + let stride = Vec::::try_from(args.remove(3))?; + let size = Vec::::try_from(args.remove(2))?; + let offset = args.remove(1).int()? as usize; + let storage = args.remove(0).persistent_load()?; + let mut storage = storage.tuple()?; + let storage_size = storage.remove(4).int()? as usize; + let path = storage.remove(2).unicode()?; + let (_module_name, class_name) = storage.remove(1).class()?; + let dtype = match class_name.as_str() { + "FloatStorage" => DType::F32, + "DoubleStorage" => DType::F64, + "HalfStorage" => DType::F16, + "BFloat16Storage" => DType::BF16, + "ByteStorage" => DType::U8, + "LongStorage" => DType::I64, + other => { + crate::bail!("unsupported storage type {other}") + } + }; + let layout = Layout::new(crate::Shape::from(size), stride, offset); + Ok((layout, dtype, path, storage_size)) +} + +#[derive(Debug, Clone)] +pub struct TensorInfo { + pub name: String, + pub dtype: DType, + pub layout: Layout, + pub path: String, + pub storage_size: usize, +} + +/// Read the tensor info from a .pth file. +/// +/// # Arguments +/// * `file` - The path to the .pth file. +/// * `verbose` - Whether to print debug information. +/// * `key` - Optional key to retrieve `state_dict` from the pth file. +pub fn read_pth_tensor_info>( + file: P, + verbose: bool, + key: Option<&str>, +) -> Result> { + let file = std::fs::File::open(file)?; + let zip_reader = std::io::BufReader::new(file); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let zip_file_names = zip + .file_names() + .map(|f| f.to_string()) + .collect::>(); + + let mut tensor_infos = vec![]; + for file_name in zip_file_names.iter() { + if !file_name.ends_with("data.pkl") { + continue; + } + let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap()); + let reader = zip.by_name(file_name)?; + let mut reader = std::io::BufReader::new(reader); + let mut stack = Stack::empty(); + stack.read_loop(&mut reader)?; + let obj = stack.finalize()?; + if VERBOSE || verbose { + println!("{obj:#?}"); + } + + let obj = match obj { + Object::Build { callable, args } => match *callable { + Object::Reduce { callable, args: _ } => match *callable { + Object::Class { + module_name, + class_name, + } if module_name == "__torch__" && class_name == "Module" => *args, + _ => continue, + }, + _ => continue, + }, + obj => obj, + }; + + // If key is provided, then we need to extract the state_dict from the object. + let obj = if let Some(key) = key { + if let Object::Dict(key_values) = obj { + key_values + .into_iter() + .find(|(k, _)| *k == Object::Unicode(key.to_owned())) + .map(|(_, v)| v) + .ok_or_else(|| E::Msg(format!("key {key} not found")))? + } else { + obj + } + } else { + obj + }; + + // If the object is a dict, then we can extract the tensor info from it. + // NOTE: We are assuming that the `obj` is state_dict by this stage. 
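+        // i.e. a dict mapping parameter names to tensors, as produced by
+        // `torch.save(model.state_dict(), path)` on the Python side.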
+ if let Object::Dict(key_values) = obj { + for (name, value) in key_values.into_iter() { + match value.into_tensor_info(name, &dir_name) { + Ok(Some(tensor_info)) => tensor_infos.push(tensor_info), + Ok(None) => {} + Err(err) => eprintln!("skipping: {err:?}"), + } + } + } + } + Ok(tensor_infos) +} + +/// Lazy tensor loader. +pub struct PthTensors { + tensor_infos: HashMap, + path: std::path::PathBuf, + // We do not store a zip reader as it needs mutable access to extract data. Instead we + // re-create a zip reader for each tensor. +} + +impl PthTensors { + pub fn new>(path: P, key: Option<&str>) -> Result { + let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?; + let tensor_infos = tensor_infos + .into_iter() + .map(|ti| (ti.name.to_string(), ti)) + .collect(); + let path = path.as_ref().to_owned(); + Ok(Self { tensor_infos, path }) + } + + pub fn tensor_infos(&self) -> &HashMap { + &self.tensor_infos + } + + pub fn get(&self, name: &str) -> Result> { + use std::io::Read; + let tensor_info = match self.tensor_infos.get(name) { + None => return Ok(None), + Some(tensor_info) => tensor_info, + }; + // We hope that the file has not changed since first reading it. + let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut reader = zip.by_name(&tensor_info.path)?; + let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous(); + let rank = tensor_info.layout.shape().rank(); + + // Reading the data is a bit tricky as it can be strided, for now only support the basic + // case and when the tensor is fortran contiguous. + if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous { + crate::bail!( + "cannot retrieve non-contiguous tensors {:?}", + tensor_info.layout + ) + } + let start_offset = tensor_info.layout.start_offset(); + if start_offset > 0 { + std::io::copy( + &mut reader.by_ref().take(start_offset as u64), + &mut std::io::sink(), + )?; + } + let tensor = Tensor::from_reader( + tensor_info.layout.shape().clone(), + tensor_info.dtype, + &mut reader, + )?; + + if rank > 1 && is_fortran_contiguous { + // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2) + let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect(); + let tensor = tensor.reshape(shape_reversed)?; + + // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4) + let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect(); + let tensor = tensor.permute(dim_indeces_reversed)?; + Ok(Some(tensor)) + } else { + Ok(Some(tensor)) + } + } +} + +/// Read all the tensors from a PyTorch pth file with a given key. +/// +/// # Arguments +/// * `path` - Path to the pth file. +/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file +/// contains multiple objects and the state_dict is the one we are interested in. +pub fn read_all_with_key>( + path: P, + key: Option<&str>, +) -> Result> { + let pth = PthTensors::new(path, key)?; + let tensor_names = pth.tensor_infos.keys(); + let mut tensors = Vec::with_capacity(tensor_names.len()); + for name in tensor_names { + if let Some(tensor) = pth.get(name)? { + tensors.push((name.to_string(), tensor)) + } + } + Ok(tensors) +} + +/// Read all the tensors from a PyTorch pth file. +/// +/// # Arguments +/// * `path` - Path to the pth file. 
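+///
+/// A minimal sketch (checkpoint path hypothetical):
+///
+/// ```ignore
+/// for (name, tensor) in read_all("model.pth")? {
+///     println!("{name}: {:?}", tensor.dims());
+/// }
+/// ```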
+pub fn read_all>(path: P) -> Result> { + read_all_with_key(path, None) +} diff --git a/mcandle/mcandle-core/src/quantized/avx.rs b/mcandle/mcandle-core/src/quantized/avx.rs new file mode 100644 index 000000000..664f7653e --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/avx.rs @@ -0,0 +1,667 @@ +use super::k_quants::{ + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +}; +use crate::Result; +use byteorder::{ByteOrder, LittleEndian}; +use half::f16; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +#[inline(always)] +pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 { + let ones = _mm256_set1_epi16(1); + let summed_pairs = _mm256_madd_epi16(ones, x); + _mm256_cvtepi32_ps(summed_pairs) +} + +#[inline(always)] +pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 { + let dot = _mm256_maddubs_epi16(ax, sy); + sum_i16_pairs_float(dot) +} + +#[inline(always)] +pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 { + let res = _mm256_extractf128_ps(x, 1); + let res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + let res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + let res = _mm_add_ss(res, _mm_movehdup_ps(res)); + _mm_cvtss_f32(res) +} + +#[inline(always)] +pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i { + let tmp = _mm_loadu_si128(rsi as *const __m128i); + let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4)); + let low_mask = _mm256_set1_epi8(0xF); + _mm256_and_si256(low_mask, bytes) +} + +#[inline(always)] +pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 { + let ax = _mm256_sign_epi8(x, x); + let sy = _mm256_sign_epi8(y, x); + mul_sum_us8_pairs_float(ax, sy) +} + +#[inline(always)] +pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + unsafe { + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d)); + let bx = bytes_from_nibbles_32(x.qs.as_ptr()); + let off = _mm256_set1_epi8(8); + let bx = _mm256_sub_epi8(bx, off); + let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i); + let q = mul_sum_i8_pairs_float(bx, by); + acc = _mm256_fmadd_ps(d, q, acc); + } + Ok(hsum_float_8(acc)) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + unsafe { + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d)); + let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i); + let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i); + let q = mul_sum_i8_pairs_float(bx, by); + acc = _mm256_fmadd_ps(d, q, acc); + } + Ok(hsum_float_8(acc)) + } +} + +#[inline(always)] +unsafe fn get_scale_shuffle(i: usize) -> __m128i { + const K_SHUFFLE: [u8; 128] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, + 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, + 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, + 11, 11, 11, 11, 11, 11, 11, 11, 12, 
12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, + 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, + ]; + _mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i)) +} + +#[inline(always)] +unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i { + const K_SHUFFLE: [u8; 256] = [ + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, + 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, + 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, + 13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, + 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, + ]; + _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i)) +} + +#[inline(always)] +unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i { + const K_SHUFFLE: [u8; 128] = [ + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, + 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, + 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, + ]; + _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i)) +} + +#[inline(always)] +pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result { + let qk = QK_K; + if n % qk != 0 { + crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}") + } + + unsafe { + let m4 = _mm256_set1_epi8(0xF); + let m2 = _mm256_set1_epi8(3); + let m32s = _mm256_set1_epi8(32); + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let mut q4 = x.ql.as_ptr(); + let mut qh = x.qh.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i); + let mut sumi = _mm256_setzero_si256(); + + for j in 0..QK_K / 128 { + let is = j * 4; + let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is)); + let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + + let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i); + q4 = q4.add(32); + let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i); + q4 = q4.add(32); + let q4bits_h = _mm256_loadu_si256(qh as *const __m256i); + qh = qh.add(32); + + let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4); + let q4h_1 = + _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4); + let q4h_2 = + _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4); + let q4h_3 = + _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4); + + let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + let q4_1 = 
_mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + let q4_2 = + _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + let q4_3 = + _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_2 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_3 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + let p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + let p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + let p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + let p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + } + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + Ok(hsum_float_8(acc)) + } +} + +#[inline(always)] +unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i { + _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1) +} + +#[inline(always)] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + + unsafe { + let m3 = _mm256_set1_epi8(3); + let m4 = _mm_set1_epi8(0xF); + + let mut acc = _mm256_setzero_ps(); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + let mut q2 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i); + let scales8 = _mm_and_si128(mins_and_scales, m4); + let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + let mins = _mm256_cvtepi8_epi16(mins8); + let prod = + _mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + let all_scales = _mm256_cvtepi8_epi16(scales8); + let l_scales = _mm256_extracti128_si256(all_scales, 0); + let h_scales = _mm256_extracti128_si256(all_scales, 1); + let scales = [ + mm256_set_m128i(l_scales, l_scales), + mm256_set_m128i(h_scales, h_scales), + ]; + + let mut sumi = _mm256_setzero_si256(); + + for scale in scales { + let q2bits = _mm256_loadu_si256(q2 as *const __m256i); + q2 = q2.add(32); + + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_2 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_3 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + let q2_0 = _mm256_and_si256(q2bits, m3); + let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + let q2_2 = 
_mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + let p0 = _mm256_maddubs_epi16(q2_0, q8_0); + let p1 = _mm256_maddubs_epi16(q2_1, q8_1); + let p2 = _mm256_maddubs_epi16(q2_2, q8_2); + let p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + let p0 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0); + let p1 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1); + let p2 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2); + let p3 = + _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3); + + let p0 = _mm256_add_epi32(p0, p1); + let p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + Ok(hsum_float_8(acc)) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + let mut aux = [0u32; 3]; + + unsafe { + let m3 = _mm256_set1_epi8(3); + let mone = _mm256_set1_epi8(1); + let m32 = _mm_set1_epi8(32); + + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + + let mut q3 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + LittleEndian::read_u32_into(&x.scales, &mut aux); + let scales128 = _mm_set_epi32( + (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32, + (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32, + ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32, + ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32, + ); + let scales128 = _mm_sub_epi8(scales128, m32); + let all_scales = _mm256_cvtepi8_epi16(scales128); + let l_scales = _mm256_extracti128_si256(all_scales, 0); + let h_scales = _mm256_extracti128_si256(all_scales, 1); + let scales = [ + mm256_set_m128i(l_scales, l_scales), + mm256_set_m128i(h_scales, h_scales), + ]; + + // high bit + let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i); + + let mut sumi = _mm256_setzero_si256(); + + for (j, scale) in scales.iter().enumerate() { + // load low 2 bits + let q3bits = _mm256_loadu_si256(q3 as *const __m256i); + q3 = q3.add(32); + + // Prepare low and high bits + // We hardcode the shifts here to avoid loading them into a separate register + let q3l_0 = _mm256_and_si256(q3bits, m3); + let q3h_0 = if j == 0 { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0) + } else { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4) + }; + let q3h_0 = _mm256_slli_epi16(q3h_0, 2); + + let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + let q3h_1 = if j == 0 { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1) + } else { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5) + }; + let q3h_1 = _mm256_slli_epi16(q3h_1, 2); + + let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + let q3h_2 = if j == 0 { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2) + } else { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6) + }; + let q3h_2 = _mm256_slli_epi16(q3h_2, 2); + + let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + let q3h_3 = 
if j == 0 { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3) + } else { + _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7) + }; + let q3h_3 = _mm256_slli_epi16(q3h_3, 2); + + // load Q8 quants + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_2 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_3 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we + // can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2 + // already subtracted (and so, it is zero if the high bit was not set, and 2 if the + // high bit was set) + let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + let p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + let p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + let p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + let p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + let p16_0 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0); + let p16_1 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1); + let p16_2 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2); + let p16_3 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3); + + // accumulate + let p16_0 = _mm256_add_epi32(p16_0, p16_1); + let p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + Ok(hsum_float_8(acc)) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") + } + let mut utmp = [0u32; 4]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4 = _mm256_set1_epi8(0xF); + + let mut acc = _mm256_setzero_ps(); + let mut acc_m = _mm_setzero_ps(); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + let mut q4 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32( + utmp[3] as i32, + utmp[2] as i32, + utmp[1] as i32, + utmp[0] as i32, + )); + + let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i); + let q8s = _mm_hadd_epi16( + _mm256_extracti128_si256(q8sums, 0), + _mm256_extracti128_si256(q8sums, 1), + ); + let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + let sc128 = 
_mm256_extracti128_si256(mins_and_scales, 0); + let scales = mm256_set_m128i(sc128, sc128); + + let mut sumi = _mm256_setzero_si256(); + + for j in 0..QK_K / 64 { + let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j)); + let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1)); + + let q4bits = _mm256_loadu_si256(q4 as *const __m256i); + q4 = q4.add(32); + let q4l = _mm256_and_si256(q4bits, m4); + let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + let q8l = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let p16l = _mm256_maddubs_epi16(q4l, q8l); + let p16l = _mm256_madd_epi16(scale_l, p16l); + sumi = _mm256_add_epi32(sumi, p16l); + + let q8h = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let p16h = _mm256_maddubs_epi16(q4h, q8h); + let p16h = _mm256_madd_epi16(scale_h, p16h); + sumi = _mm256_add_epi32(sumi, p16h); + } + + let vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + } + + let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m)) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") + } + let mut utmp = [0u32; 4]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4 = _mm256_set1_epi8(0xF); + let mzero = _mm_setzero_si128(); + let mone = _mm256_set1_epi8(1); + + let mut acc = _mm256_setzero_ps(); + let mut summs = 0.0; + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + let mut q5 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32( + utmp[3] as i32, + utmp[2] as i32, + utmp[1] as i32, + utmp[0] as i32, + )); + + let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i); + let q8s = _mm_hadd_epi16( + _mm256_extracti128_si256(q8sums, 0), + _mm256_extracti128_si256(q8sums, 1), + ); + let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0) as f32; + + let sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + let scales = mm256_set_m128i(sc128, sc128); + + let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i); + let mut hmask = mone; + + let mut sumi = _mm256_setzero_si256(); + + for j in 0..QK_K / 64 { + let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j)); + let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1)); + + let q5bits = _mm256_loadu_si256(q5 as *const __m256i); + q5 = q5.add(32); + + //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register + let q5l_0 = _mm256_and_si256(q5bits, m4); + let q5l_0_shift_input = _mm256_and_si256(hbits, hmask); + let q5l_0_right_shift = match j { + 0 => _mm256_srli_epi16(q5l_0_shift_input, 0), + 1 => _mm256_srli_epi16(q5l_0_shift_input, 2), + 2 => _mm256_srli_epi16(q5l_0_shift_input, 4), + 3 => 
_mm256_srli_epi16(q5l_0_shift_input, 6), + _ => unreachable!(), + }; + let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4); + let q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + let q5l_1_shift_input = _mm256_and_si256(hbits, hmask); + let q5l_1_right_shift = match j { + 0 => _mm256_srli_epi16(q5l_1_shift_input, 1), + 1 => _mm256_srli_epi16(q5l_1_shift_input, 3), + 2 => _mm256_srli_epi16(q5l_1_shift_input, 5), + 3 => _mm256_srli_epi16(q5l_1_shift_input, 7), + _ => unreachable!(), + }; + + let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4); + let q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + let p16_0 = _mm256_madd_epi16(scale_0, p16_0); + let p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + } + let vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + } + Ok(hsum_float_8(acc) + summs) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { + let qk = QK_K; + if n % qk != 0 { + crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}") + } + + unsafe { + let mut acc = _mm256_setzero_ps(); + for (xs, ys) in xs.iter().zip(ys.iter()) { + let mut sumi = _mm256_setzero_si256(); + let x_qs = xs.qs.as_ptr(); + let y_qs = ys.qs.as_ptr(); + for j in (0..QK_K).step_by(32) { + let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i); + let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i); + + let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0)); + let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0)); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0)); + + let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1)); + let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1)); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1)); + } + let d = _mm256_set1_ps(xs.d * ys.d); + acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc); + } + Ok(hsum_float_8(acc)) + } +} diff --git a/mcandle/mcandle-core/src/quantized/cuda.rs b/mcandle/mcandle-core/src/quantized/cuda.rs new file mode 100644 index 000000000..b595e5339 --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/cuda.rs @@ -0,0 +1,766 @@ +use super::{GgmlDType, QStorage}; +use crate::quantized::k_quants::GgmlType; +use crate::{backend::BackendDevice, cuda_backend::WrapErr}; +use crate::{CudaDevice, CudaStorage, Result}; +use half::f16; + +use cudarc::driver::{CudaSlice, CudaView, DeviceSlice}; + +#[derive(Clone, Debug)] +struct PaddedCudaSlice { + inner: CudaSlice, + len: usize, +} + +#[derive(Clone, Debug)] +pub struct QCudaStorage { + data: PaddedCudaSlice, + dtype: GgmlDType, + device: CudaDevice, +} + +static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); + +pub fn set_force_dmmv(f: bool) { + FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed) +} + +pub const WARP_SIZE: usize = 32; +pub const MMQ_X_Q4_0_AMPERE: usize = 4; +pub const MMQ_Y_Q4_0_AMPERE: usize = 32; +pub const NWARPS_Q4_0_AMPERE: usize = 4; +pub const GGML_CUDA_MMV_X: usize = 32; +pub const GGML_CUDA_MMV_Y: usize = 1; +pub const 
CUDA_QUANTIZE_BLOCK_SIZE: usize = 256; +pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256; +pub const MATRIX_ROW_PADDING: usize = 512; + +fn ceil_div(p: usize, q: usize) -> usize { + (p + q - 1) / q +} + +fn pad(p: usize, q: usize) -> usize { + ceil_div(p, q) * q +} + +fn quantize_q8_1( + src: &CudaView, + dst: &mut CudaSlice, + elem_count: usize, + ky: usize, + dev: &CudaDevice, +) -> Result<()> { + use cudarc::driver::LaunchAsync; + + let kx = elem_count; + let kx_padded = pad(kx, MATRIX_ROW_PADDING); + let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE); + let func = dev.get_or_load_func("quantize_q8_1", mcandle_kernels::QUANTIZED)?; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (num_blocks as u32, ky as u32, 1), + block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1), + shared_mem_bytes: 0, + }; + let params = (src, dst, kx as i32, kx_padded as i32); + unsafe { func.launch(cfg, params) }.w()?; + Ok(()) +} + +fn dequantize_f32( + data: &PaddedCudaSlice, + dtype: GgmlDType, + elem_count: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let nb = (elem_count + 255) / 256; + let (kernel_name, is_k, block_dim, num_blocks) = match dtype { + GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb), + GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb), + GgmlDType::Q5_0 => ( + "dequantize_block_q5_0_f32", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q5_1 => ( + "dequantize_block_q5_1_f32", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb), + GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb), + GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb), + GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb), + GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb), + GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb), + GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb), + _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, mcandle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(elem_count).w()? }; + // See e.g. 
+ // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (num_blocks as u32, 1, 1), + block_dim: (block_dim as u32, 1, 1), + shared_mem_bytes: 0, + }; + + if is_k { + let params = (&data.inner, &dst); + unsafe { func.launch(cfg, params) }.w()?; + } else { + let nb32 = match dtype { + GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count, + _ => elem_count / 32, + }; + let params = (&data.inner, &dst, nb32 as i32); + unsafe { func.launch(cfg, params) }.w()?; + } + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +fn dequantize_f16( + data: &PaddedCudaSlice, + dtype: GgmlDType, + elem_count: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let nb = (elem_count + 255) / 256; + let (kernel_name, is_k, block_dim, num_blocks) = match dtype { + GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb), + GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb), + GgmlDType::Q5_0 => ( + "dequantize_block_q5_0_f16", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q5_1 => ( + "dequantize_block_q5_1_f16", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE), + ), + GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb), + GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb), + GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb), + GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb), + GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb), + GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb), + GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb), + _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, mcandle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(elem_count).w()? }; + // See e.g. 
+ // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (num_blocks as u32, 1, 1), + block_dim: (block_dim as u32, 1, 1), + shared_mem_bytes: 0, + }; + + if is_k { + let params = (&data.inner, &dst); + unsafe { func.launch(cfg, params) }.w()?; + } else { + let nb32 = match dtype { + GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count, + _ => elem_count / 32, + }; + let params = (&data.inner, &dst, nb32 as i32); + unsafe { func.launch(cfg, params) }.w()?; + } + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +fn dequantize_mul_mat_vec( + data: &PaddedCudaSlice, + y: &CudaView, + dtype: GgmlDType, + ncols: usize, + nrows: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let data_elems = data.len / dtype.type_size() * dtype.block_size(); + if data_elems < ncols * nrows { + crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) + } + if y.len() != ncols { + crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len()) + } + let kernel_name = match dtype { + GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda", + GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda", + GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda", + GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda", + GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda", + GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k", + GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k", + GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k", + GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k", + GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k", + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, mcandle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(nrows).w()? }; + let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y); + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (block_num_y as u32, 1, 1), + block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1), + shared_mem_bytes: 0, + }; + + let params = (&data.inner, y, &dst, ncols as i32, nrows as i32); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +fn mul_mat_vec_via_q8_1( + data: &PaddedCudaSlice, + y: &CudaView, + dtype: GgmlDType, + ncols: usize, + nrows: usize, + b_size: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let data_elems = data.len / dtype.type_size() * dtype.block_size(); + if data_elems < ncols * nrows { + crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems) + } + if y.len() != ncols * b_size { + crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len()) + } + if b_size == 0 || b_size > 8 { + crate::bail!("only bsize between 1 and 8 are supported, got {b_size}") + } + // Start by quantizing y + let ncols_padded = pad(ncols, MATRIX_ROW_PADDING); + let y_size_in_bytes = + b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? 
}; + quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?; + + let kernel_name = match dtype { + GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda", + GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda", + GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda", + GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda", + GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda", + GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda", + GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda", + GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda", + GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda", + GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda", + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let kernel_name = format!("{kernel_name}{b_size}"); + let func = dev.get_or_load_func(&kernel_name, mcandle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(nrows * b_size).w()? }; + // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98 + let (nblocks, nwarps) = match b_size { + 1 => (nrows as u32, 4), + 2..=4 => ((nrows as u32 + 1) / 2, 4), + 5..=8 => ((nrows as u32 + 1) / 2, 2), + _ => crate::bail!("unexpected bsize {b_size}"), + }; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (nblocks, 1, 1), + block_dim: (WARP_SIZE as u32, nwarps, 1), + shared_mem_bytes: 0, + }; + + let params = ( + &data.inner, + &y_q8_1, + &dst, + /* ncols_x */ ncols as i32, + /* nrows_x */ nrows as i32, + /* nrows_y */ ncols_padded as i32, + /* nrows_dst */ nrows as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +#[allow(clippy::too_many_arguments)] +fn mul_mat_via_q8_1( + data: &PaddedCudaSlice, + y: &CudaView, + dtype: GgmlDType, + x_rows: usize, + x_cols: usize, + y_rows: usize, + y_cols: usize, + dev: &CudaDevice, +) -> Result { + use cudarc::driver::LaunchAsync; + + let data_elems = data.len / dtype.type_size() * dtype.block_size(); + if data_elems < x_rows * x_cols { + crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems) + } + if y.len() != y_rows * y_cols { + crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len()) + } + if x_cols != y_rows { + crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}") + } + let k = x_cols; + // Start by quantizing y + let k_padded = pad(k, MATRIX_ROW_PADDING); + let y_size_in_bytes = + k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? }; + quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?; + + let (kernel_name, mmq_x, mmq_y) = match dtype { + GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128), + GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128), + GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64), + GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64), + GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64), + GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128), + GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128), + GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128), + GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128), + GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64), + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, mcandle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::(x_rows * y_cols).w()? 
}; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: ( + ceil_div(x_rows, mmq_y) as u32, + ceil_div(y_cols, mmq_x) as u32, + 1, + ), + block_dim: (WARP_SIZE as u32, 4, 1), + shared_mem_bytes: 0, + }; + + let params = ( + /* vx */ &data.inner, + /* vy */ &y_q8_1, + /* dst */ &dst, + /* ncols_x */ x_cols as i32, + /* nrows_x */ x_rows as i32, + /* ncols_y */ y_cols as i32, + /* nrows_y */ k_padded as i32, + /* nrows_dst */ x_rows as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +impl QCudaStorage { + pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result { + let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size(); + let padded_size_in_bytes = + ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size(); + let inner = device.alloc_zeros::(padded_size_in_bytes).w()?; + Ok(QCudaStorage { + data: PaddedCudaSlice { + inner, + len: size_in_bytes, + }, + device: device.clone(), + dtype, + }) + } + + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn device(&self) -> &CudaDevice { + &self.device + } + + pub fn dequantize(&self, elem_count: usize) -> Result { + fn deq(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> { + let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) }; + let vec = slice.to_vec(); + T::to_float(&vec, dst) + } + + let fast_kernel = matches!( + self.dtype, + GgmlDType::Q4_0 + | GgmlDType::Q4_1 + | GgmlDType::Q5_0 + | GgmlDType::Q5_1 + | GgmlDType::Q8_0 + | GgmlDType::Q2K + | GgmlDType::Q3K + | GgmlDType::Q4K + | GgmlDType::Q5K + | GgmlDType::Q6K + | GgmlDType::Q8K + ); + if fast_kernel { + return dequantize_f32(&self.data, self.dtype, elem_count, self.device()); + } + // Run the dequantization on cpu. + + let buffer = self + .device + .dtoh_sync_copy(&self.data.inner.slice(..self.data.len)) + .w()?; + let mut out = vec![0.0; elem_count]; + let block_len = elem_count / self.dtype.block_size(); + match self.dtype { + GgmlDType::F32 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::F16 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::BF16 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q4_0 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q4_1 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q5_0 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q5_1 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q8_0 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q8_1 => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q2K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q3K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q4K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q5K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q6K => deq::(&buffer, block_len, &mut out)?, + GgmlDType::Q8K => deq::(&buffer, block_len, &mut out)?, + } + + self.device + .storage_from_cpu_storage(&crate::CpuStorage::F32(out)) + } + + pub fn dequantize_f16(&self, elem_count: usize) -> Result { + dequantize_f16(&self.data, self.dtype, elem_count, self.device()) + } + + pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> { + // Run the quantization on cpu. + let src = match &src.slice { + crate::cuda_backend::CudaStorageSlice::F32(data) => { + self.device.dtoh_sync_copy(data).w()? 
+ } + _ => crate::bail!("only f32 can be quantized"), + }; + let src_len = src.len(); + let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + qcpu_storage.quantize(&src)?; + let data = qcpu_storage.data()?; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len).w()? }; + self.device + .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len())) + .w()?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; + Ok(()) + } + + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Run the quantization on cpu. + let src_len = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let data = qcpu_storage.data()?; + let padded_len = + data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size(); + let mut inner = unsafe { self.device.alloc::(padded_len).w()? }; + self.device + .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len())) + .w()?; + self.data = PaddedCudaSlice { + inner, + len: data.len(), + }; + Ok(()) + } + + pub fn storage_size_in_bytes(&self) -> usize { + self.data.len + } + + pub fn fwd( + &self, + self_shape: &crate::Shape, + storage: &CudaStorage, + layout: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { + 1 + } else { + 8 + }; + let use_vec_kernel = match layout.shape().dims() { + [b, m, _k] => b * m <= max_bm, + [b, _k] => *b <= max_bm, + _ => false, + }; + if use_vec_kernel { + self.dequantize_matmul_vec(self_shape, storage, layout) + } else { + self.dequantize_matmul(self_shape, storage, layout) + } + } + + pub fn data(&self) -> Result> { + self.device + .dtoh_sync_copy(&self.data.inner.slice(..self.data.len)) + .w() + } +} + +impl QCudaStorage { + fn dequantize_matmul_vec( + &self, + self_shape: &crate::Shape, + rhs: &CudaStorage, + rhs_l: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + let (nrows, ncols) = self_shape.dims2()?; + let rhs = rhs.as_cuda_slice::()?; + let rhs = match rhs_l.contiguous_offsets() { + Some((o1, o2)) => rhs.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?, + }; + let (b_size, k) = match rhs_l.shape().dims() { + [b, m, k] => (b * m, *k), + [b, k] => (*b, *k), + _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()), + }; + if ncols != k { + crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape()) + } + + let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { + dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? + } else { + mul_mat_vec_via_q8_1( + &self.data, + &rhs, + self.dtype, + ncols, + nrows, + b_size, + self.device(), + )? 
+ }; + let mut out_shape = rhs_l.shape().dims().to_vec(); + out_shape.pop(); + out_shape.push(nrows); + Ok((out, out_shape.into())) + } + + fn dequantize_matmul( + &self, + self_shape: &crate::Shape, + storage: &CudaStorage, + layout: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + use crate::backend::BackendStorage; + let (n, k) = self_shape.dims2()?; + let (b, m, k2) = match layout.shape().dims() { + &[b, m, k2] => (b, m, k2), + &[m, k2] => (1, m, k2), + s => crate::bail!("unexpected shape for input {s:?}"), + }; + if k2 != k { + crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape()) + } + + let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { + let data_f32 = self.dequantize(n * k)?; + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; + storage.matmul_with_alpha(&data_f32, None, (b, m, n, k), layout, &rhs_l)? + } else { + let storage = storage.as_cuda_slice::()?; + let storage = match layout.contiguous_offsets() { + Some((o1, o2)) => storage.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { + op: "quantized-matmul", + } + .bt())?, + }; + mul_mat_via_q8_1( + &self.data, + &storage, + self.dtype, + /* x_rows */ n, + /* x_cols */ k, + /* y_rows */ k, + /* y_cols */ b * m, + self.device(), + )? + }; + let mut out_shape = layout.shape().dims().to_vec(); + out_shape.pop(); + out_shape.push(n); + Ok((out, out_shape.into())) + } +} + +pub fn load_quantized( + device: &CudaDevice, + data: &[T], +) -> Result { + let data = unsafe { + std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data)) + }; + let dtype = T::DTYPE; + let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size(); + let mut inner = unsafe { device.alloc::(padded_len).w()? }; + device + .htod_sync_copy_into(data, &mut inner.slice_mut(..data.len())) + .w()?; + Ok(QStorage::Cuda(QCudaStorage { + data: PaddedCudaSlice { + inner, + len: data.len(), + }, + device: device.clone(), + dtype, + })) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn cuda_quantize_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let el = 256; + let el_padded = pad(el, MATRIX_ROW_PADDING); + let y_size_in_bytes = + el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::(y_size_in_bytes).w()? }; + let vs: Vec = (0..el).map(|v| v as f32).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?; + Ok(()) + } + + #[test] + fn cuda_mmv_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let ncols = 256; + let vs: Vec = (0..ncols).map(|v| v as f32).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?; + xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?; + let cuda_storage = mul_mat_vec_via_q8_1( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* ncols */ ncols, + /* nrows */ 1, + /* b_size */ 1, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + assert_eq!(vs.len(), 1); + // for n = 255, n.(n+1).(2n+1) / 6 = 5559680 + // Q8 means 1/256 precision. 
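+        // Illustrative cross-check, added for clarity and not part of the original test: the
+        // exact f32 dot product of the input vector with itself is the sum of squares
+        // 0^2 + 1^2 + ... + 255^2 = 255 * 256 * 511 / 6 = 5_559_680, so both quantized
+        // results below land within roughly 0.04% of the exact value.
+        let exact: f32 = (0..ncols).map(|v| (v * v) as f32).sum();
+        assert_eq!(exact, 5_559_680.0);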
+ assert_eq!(vs[0], 5561664.5); + + let cuda_storage = dequantize_mul_mat_vec( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* ncols */ ncols, + /* nrows */ 1, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + assert_eq!(vs.len(), 1); + assert_eq!(vs[0], 5561851.0); + Ok(()) + } + + #[test] + fn cuda_mm_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let ncols = 256; + let vs: Vec = (0..ncols * 4).map(|v| v as f32 / 4.).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?; + xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?; + let cuda_storage = mul_mat_via_q8_1( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* x_rows */ 4, + /* x_cols */ ncols, + /* y_rows */ ncols, + /* y_cols */ 4, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + + /* + x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256) + x @ x.t() / 16 + tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000], + [ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000], + [ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000], + [ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]]) + */ + assert_eq!(vs.len(), 16); + assert_eq!(vs[0], 347604.0); + assert_eq!(vs[1], 888153.06); + assert_eq!(vs[4], 869780.7); + assert_eq!(vs[5], 2483145.0); + assert_eq!(vs[11], 9407368.0); + assert_eq!(vs[14], 9470856.0); + assert_eq!(vs[15], 13138824.0); + Ok(()) + } + + // The following test used to fail under compute-sanitizer until #2526. + #[test] + fn cuda_mm_q8_1_pad() -> Result<()> { + let dev = CudaDevice::new(0)?; + let (x_rows, ncols, y_cols) = (4, 16, 2048); + let vs: Vec = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?; + xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?; + let cuda_storage = mul_mat_via_q8_1( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* x_rows */ x_rows, + /* x_cols */ ncols, + /* y_rows */ ncols, + /* y_cols */ y_cols, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::()?; + let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/quantized/dummy_cuda.rs b/mcandle/mcandle-core/src/quantized/dummy_cuda.rs new file mode 100644 index 000000000..23a9e05bc --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/dummy_cuda.rs @@ -0,0 +1,62 @@ +#![allow(unused)] +use super::GgmlDType; +use crate::{CudaDevice, CudaStorage, Error, Result}; + +pub struct QCudaStorage { + dtype: GgmlDType, + device: CudaDevice, +} + +impl QCudaStorage { + pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn device(&self) -> &CudaDevice { + &self.device + } + + pub fn dequantize(&self, _elem_count: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn dequantize_f16(&self, _elem_count: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn 
storage_size_in_bytes(&self) -> usize {
+        0
+    }
+
+    pub fn fwd(
+        &self,
+        _self_shape: &crate::Shape,
+        _storage: &CudaStorage,
+        _layout: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub fn data(&self) -> Result<Vec<u8>> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+}
+
+pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
+    _device: &CudaDevice,
+    _data: &[T],
+) -> Result<super::QStorage> {
+    Err(Error::NotCompiledWithCudaSupport)
+}
diff --git a/mcandle/mcandle-core/src/quantized/dummy_metal.rs b/mcandle/mcandle-core/src/quantized/dummy_metal.rs
new file mode 100644
index 000000000..c5c8db928
--- /dev/null
+++ b/mcandle/mcandle-core/src/quantized/dummy_metal.rs
@@ -0,0 +1,58 @@
+#![allow(unused)]
+use super::GgmlDType;
+use crate::{Error, MetalDevice, MetalStorage, Result};
+
+pub struct QMetalStorage {
+    dtype: GgmlDType,
+    device: MetalDevice,
+}
+
+impl QMetalStorage {
+    pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn dtype(&self) -> GgmlDType {
+        self.dtype
+    }
+
+    pub fn device(&self) -> &MetalDevice {
+        &self.device
+    }
+
+    pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn quantize_onto(&mut self, _src: &crate::CpuStorage) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn storage_size_in_bytes(&self) -> usize {
+        0
+    }
+
+    pub fn fwd(
+        &self,
+        _self_shape: &crate::Shape,
+        _storage: &MetalStorage,
+        _layout: &crate::Layout,
+    ) -> Result<(MetalStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
+    pub fn data(&self) -> Result<Vec<u8>> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+}
+
+pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
+    _device: &MetalDevice,
+    _data: &[T],
+) -> Result<super::QStorage> {
+    Err(Error::NotCompiledWithMetalSupport)
+}
diff --git a/mcandle/mcandle-core/src/quantized/ggml_file.rs b/mcandle/mcandle-core/src/quantized/ggml_file.rs
new file mode 100644
index 000000000..ea5ec0257
--- /dev/null
+++ b/mcandle/mcandle-core/src/quantized/ggml_file.rs
@@ -0,0 +1,266 @@
+//! Support for the GGML file format.
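+//!
+//! The on-disk layout consumed by [`Content::read`] below is: a magic/version word
+//! ([`VersionedMagic`]), the [`HParams`] block (seven little-endian u32 fields), the
+//! [`Vocab`] (one length-prefixed byte string plus an f32 score per token), and then the
+//! tensors until end of file. For the GGJT variants each tensor's raw data is additionally
+//! aligned to a 32-byte boundary (see `VersionedMagic::align32`).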
+ +use super::{k_quants, GgmlDType, QStorage}; +use crate::{Device, Result}; +use byteorder::{LittleEndian, ReadBytesExt}; +use std::collections::HashMap; + +// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37 +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Magic { + Ggjt, + Ggla, + Ggmf, + Ggml, + Ggsn, +} + +impl TryFrom for Magic { + type Error = crate::Error; + fn try_from(value: u32) -> Result { + let magic = match value { + 0x67676a74 => Self::Ggjt, + 0x67676c61 => Self::Ggla, + 0x67676d66 => Self::Ggmf, + 0x67676d6c => Self::Ggml, + 0x6767736e => Self::Ggsn, + _ => crate::bail!("unknown magic {value:08x}"), + }; + Ok(magic) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VersionedMagic { + GgmlUnversioned, + GgmfV1, + GgjtV1, + GgjtV2, + GgjtV3, +} + +impl VersionedMagic { + fn read(reader: &mut R) -> Result { + let magic = reader.read_u32::()?; + let magic = Magic::try_from(magic)?; + if magic == Magic::Ggml { + return Ok(Self::GgmlUnversioned); + } + let version = reader.read_u32::()?; + let versioned_magic = match (magic, version) { + (Magic::Ggmf, 1) => Self::GgmfV1, + (Magic::Ggjt, 1) => Self::GgjtV1, + (Magic::Ggjt, 2) => Self::GgjtV2, + (Magic::Ggjt, 3) => Self::GgjtV3, + _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"), + }; + Ok(versioned_magic) + } + + fn align32(&self) -> bool { + match self { + Self::GgmlUnversioned | Self::GgmfV1 => false, + Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HParams { + pub n_vocab: u32, + pub n_embd: u32, + pub n_mult: u32, + pub n_head: u32, + pub n_layer: u32, + pub n_rot: u32, + pub ftype: u32, +} + +impl HParams { + fn read(reader: &mut R) -> Result { + let n_vocab = reader.read_u32::()?; + let n_embd = reader.read_u32::()?; + let n_mult = reader.read_u32::()?; + let n_head = reader.read_u32::()?; + let n_layer = reader.read_u32::()?; + let n_rot = reader.read_u32::()?; + let ftype = reader.read_u32::()?; + Ok(Self { + n_vocab, + n_embd, + n_mult, + n_head, + n_layer, + n_rot, + ftype, + }) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Vocab { + pub token_score_pairs: Vec<(Vec, f32)>, +} + +impl Vocab { + fn read(reader: &mut R, n_vocab: usize) -> Result { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556 + let mut token_score_pairs = Vec::with_capacity(n_vocab); + for _index in 0..n_vocab { + let len = reader.read_u32::()? as usize; + let mut word = vec![0u8; len]; + reader.read_exact(&mut word)?; + let score = reader.read_f32::()?; + token_score_pairs.push((word, score)) + } + Ok(Self { token_score_pairs }) + } +} + +fn from_raw_data( + raw_data: &[u8], + size_in_bytes: usize, + dims: Vec, + device: &Device, +) -> Result { + let raw_data_ptr = raw_data.as_ptr(); + let n_blocks = size_in_bytes / std::mem::size_of::(); + let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; + let data: QStorage = match device { + Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())), + Device::Metal(metal) => super::metal::load_quantized(metal, data)?, + Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?, + }; + super::QTensor::new(data, dims) +} + +/// Creates a [Tensor] from a raw GGML tensor. 
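+///
+/// `raw_data` is expected to contain the quantized blocks for `dims`: the element count must
+/// be a multiple of the dtype's block size, and the number of bytes consumed is
+/// `elem_count / block_size * type_size`. As a worked example (assuming the standard GGML
+/// Q4_0 layout of 32 elements per 18-byte block), a 4096-element Q4_0 tensor is
+/// 4096 / 32 = 128 blocks of 18 bytes, i.e. 2304 bytes.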
+pub fn qtensor_from_ggml( + ggml_dtype: GgmlDType, + raw_data: &[u8], + dims: Vec, + device: &Device, +) -> Result { + let tensor_elems = dims.iter().product::(); + let block_size = ggml_dtype.block_size(); + if tensor_elems % block_size != 0 { + crate::bail!( + "the number of elements {tensor_elems} is not divisible by the block size {block_size}" + ) + } + let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size(); + + match ggml_dtype { + GgmlDType::F32 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::F16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::BF16 => from_raw_data::(raw_data, size_in_bytes, dims, device), + GgmlDType::Q4_0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q4_1 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q5_0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q5_1 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q8_0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q2K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q3K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q4K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q5K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q6K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), + } +} + +fn read_one_tensor( + reader: &mut R, + magic: VersionedMagic, + device: &Device, +) -> Result<(String, super::QTensor)> { + let n_dims = reader.read_u32::()?; + let name_len = reader.read_u32::()?; + let ggml_dtype = reader.read_u32::()?; + let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?; + let mut dims = vec![0u32; n_dims as usize]; + reader.read_u32_into::(&mut dims)?; + // The dimensions are stored in reverse order, see for example: + // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969 + dims.reverse(); + let mut name = vec![0u8; name_len as usize]; + reader.read_exact(&mut name)?; + let name = String::from_utf8_lossy(&name).into_owned(); + + if magic.align32() { + let pos = reader.stream_position()?; + reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?; + } + let dims = dims.iter().map(|&u| u as usize).collect::>(); + let tensor_elems = dims.iter().product::(); + let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size(); + // TODO: Mmap version to avoid copying the data around? 
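+    // Note: the bytes are read into a heap buffer here and copied once more when the backend
+    // storage is built in `qtensor_from_ggml`, which is the extra copy the mmap idea above
+    // would remove.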
+ let mut raw_data = vec![0u8; size_in_bytes]; + reader.read_exact(&mut raw_data)?; + match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) { + Ok(tensor) => Ok((name, tensor)), + Err(e) => crate::bail!("Error creating tensor {name}: {e}"), + } +} + +pub struct Content { + pub magic: VersionedMagic, + pub hparams: HParams, + pub vocab: Vocab, + pub tensors: HashMap, + pub device: Device, +} + +impl Content { + pub fn read( + reader: &mut R, + device: &Device, + ) -> Result { + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 + let last_position = reader.seek(std::io::SeekFrom::End(0))?; + reader.seek(std::io::SeekFrom::Start(0))?; + let magic = VersionedMagic::read(reader)?; + let hparams = HParams::read(reader)?; + let vocab = Vocab::read(reader, hparams.n_vocab as usize)?; + let mut tensors = HashMap::new(); + + while reader.stream_position()? != last_position { + let (name, tensor) = read_one_tensor(reader, magic, device)?; + tensors.insert(name, tensor); + } + let device = device.clone(); + Ok(Self { + magic, + hparams, + vocab, + tensors, + device, + }) + } + + pub fn remove(&mut self, name: &str) -> Result { + match self.tensors.remove(name) { + None => crate::bail!("cannot find tensor with name '{name}'"), + Some(tensor) => Ok(tensor), + } + } +} diff --git a/mcandle/mcandle-core/src/quantized/gguf_file.rs b/mcandle/mcandle-core/src/quantized/gguf_file.rs new file mode 100644 index 000000000..d3fe4b585 --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/gguf_file.rs @@ -0,0 +1,539 @@ +//! Support for the GGUF file format. +//! +//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md + +use super::{GgmlDType, QTensor}; +use crate::{Device, Result}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::collections::HashMap; + +pub const DEFAULT_ALIGNMENT: u64 = 32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Magic { + Gguf, +} + +impl TryFrom for Magic { + type Error = crate::Error; + fn try_from(value: u32) -> Result { + let magic = match value { + 0x46554747 | 0x47475546 => Self::Gguf, + _ => crate::bail!("unknown magic 0x{value:08x}"), + }; + Ok(magic) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum VersionedMagic { + GgufV1, + GgufV2, + GgufV3, +} + +impl VersionedMagic { + fn read(reader: &mut R) -> Result { + let magic = reader.read_u32::()?; + let magic = Magic::try_from(magic)?; + let version = reader.read_u32::()?; + let versioned_magic = match (magic, version) { + (Magic::Gguf, 1) => Self::GgufV1, + (Magic::Gguf, 2) => Self::GgufV2, + (Magic::Gguf, 3) => Self::GgufV3, + _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"), + }; + Ok(versioned_magic) + } +} + +#[derive(Debug)] +pub struct TensorInfo { + pub ggml_dtype: GgmlDType, + pub shape: crate::Shape, + pub offset: u64, +} + +impl TensorInfo { + pub fn read( + &self, + reader: &mut R, + tensor_data_offset: u64, + device: &Device, + ) -> Result { + let tensor_elems = self.shape.elem_count(); + let block_size = self.ggml_dtype.block_size(); + if tensor_elems % block_size != 0 { + crate::bail!( + "the number of elements {tensor_elems} is not divisible by the block size {block_size}" + ) + } + let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size(); + let mut raw_data = vec![0u8; size_in_bytes]; + reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?; + reader.read_exact(&mut raw_data)?; + super::ggml_file::qtensor_from_ggml( + self.ggml_dtype, + 
&raw_data, + self.shape.dims().to_vec(), + device, + ) + } +} + +#[derive(Debug)] +pub struct Content { + pub magic: VersionedMagic, + pub metadata: HashMap, + pub tensor_infos: HashMap, + pub tensor_data_offset: u64, +} + +fn read_string(reader: &mut R, magic: &VersionedMagic) -> Result { + let len = match magic { + VersionedMagic::GgufV1 => reader.read_u32::()? as usize, + VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => { + reader.read_u64::()? as usize + } + }; + let mut v = vec![0u8; len]; + reader.read_exact(&mut v)?; + // GGUF strings are supposed to be non-null terminated but in practice this happens. + while let Some(0) = v.last() { + v.pop(); + } + // GGUF strings are utf8 encoded but there are cases that don't seem to be valid. + Ok(String::from_utf8_lossy(&v).into_owned()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ValueType { + // The value is a 8-bit unsigned integer. + U8, + // The value is a 8-bit signed integer. + I8, + // The value is a 16-bit unsigned little-endian integer. + U16, + // The value is a 16-bit signed little-endian integer. + I16, + // The value is a 32-bit unsigned little-endian integer. + U32, + // The value is a 32-bit signed little-endian integer. + I32, + // The value is a 64-bit unsigned little-endian integer. + U64, + // The value is a 64-bit signed little-endian integer. + I64, + // The value is a 32-bit IEEE754 floating point number. + F32, + // The value is a 64-bit IEEE754 floating point number. + F64, + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + Bool, + // The value is a UTF-8 non-null-terminated string, with length prepended. + String, + // The value is an array of other values, with the length and type prepended. + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. 
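+    // On disk (see `Value::read`) an array is encoded as the element type id (u32), the
+    // element count (u32 in GGUF v1, u64 in v2/v3), and then the elements back to back;
+    // the writer below rejects arrays whose elements do not all share a single type.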
+ Array, +} + +#[derive(Debug, Clone)] +pub enum Value { + U8(u8), + I8(i8), + U16(u16), + I16(i16), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + F32(f32), + F64(f64), + Bool(bool), + String(String), + Array(Vec), +} + +impl Value { + pub fn value_type(&self) -> ValueType { + match self { + Self::U8(_) => ValueType::U8, + Self::I8(_) => ValueType::I8, + Self::U16(_) => ValueType::U16, + Self::I16(_) => ValueType::I16, + Self::U32(_) => ValueType::U32, + Self::I32(_) => ValueType::I32, + Self::U64(_) => ValueType::U64, + Self::I64(_) => ValueType::I64, + Self::F32(_) => ValueType::F32, + Self::F64(_) => ValueType::F64, + Self::Bool(_) => ValueType::Bool, + Self::String(_) => ValueType::String, + Self::Array(_) => ValueType::Array, + } + } + + pub fn to_u8(&self) -> Result { + match self { + Self::U8(v) => Ok(*v), + v => crate::bail!("not a u8 {v:?}"), + } + } + + pub fn to_i8(&self) -> Result { + match self { + Self::I8(v) => Ok(*v), + v => crate::bail!("not a i8 {v:?}"), + } + } + + pub fn to_u16(&self) -> Result { + match self { + Self::U16(v) => Ok(*v), + v => crate::bail!("not a u16 {v:?}"), + } + } + + pub fn to_i16(&self) -> Result { + match self { + Self::I16(v) => Ok(*v), + v => crate::bail!("not a i16 {v:?}"), + } + } + + pub fn to_u32(&self) -> Result { + match self { + Self::U32(v) => Ok(*v), + v => crate::bail!("not a u32 {v:?}"), + } + } + + pub fn to_i32(&self) -> Result { + match self { + Self::I32(v) => Ok(*v), + v => crate::bail!("not a i32 {v:?}"), + } + } + + /// This will also automatically upcast any integral types which will not truncate. + pub fn to_u64(&self) -> Result { + match self { + Self::U64(v) => Ok(*v), + // Autoupcast cases here + Self::U8(v) => Ok(*v as u64), + Self::U16(v) => Ok(*v as u64), + Self::U32(v) => Ok(*v as u64), + Self::Bool(v) => Ok(*v as u64), + v => crate::bail!("not a u64 or upcastable to u64 {v:?}"), + } + } + + pub fn to_i64(&self) -> Result { + match self { + Self::I64(v) => Ok(*v), + v => crate::bail!("not a i64 {v:?}"), + } + } + + pub fn to_f32(&self) -> Result { + match self { + Self::F32(v) => Ok(*v), + v => crate::bail!("not a f32 {v:?}"), + } + } + + pub fn to_f64(&self) -> Result { + match self { + Self::F64(v) => Ok(*v), + v => crate::bail!("not a f64 {v:?}"), + } + } + + pub fn to_bool(&self) -> Result { + match self { + Self::Bool(v) => Ok(*v), + v => crate::bail!("not a bool {v:?}"), + } + } + + pub fn to_vec(&self) -> Result<&Vec> { + match self { + Self::Array(v) => Ok(v), + v => crate::bail!("not a vec {v:?}"), + } + } + + pub fn to_string(&self) -> Result<&String> { + match self { + Self::String(v) => Ok(v), + v => crate::bail!("not a string {v:?}"), + } + } + + fn read( + reader: &mut R, + value_type: ValueType, + magic: &VersionedMagic, + ) -> Result { + let v = match value_type { + ValueType::U8 => Self::U8(reader.read_u8()?), + ValueType::I8 => Self::I8(reader.read_i8()?), + ValueType::U16 => Self::U16(reader.read_u16::()?), + ValueType::I16 => Self::I16(reader.read_i16::()?), + ValueType::U32 => Self::U32(reader.read_u32::()?), + ValueType::I32 => Self::I32(reader.read_i32::()?), + ValueType::U64 => Self::U64(reader.read_u64::()?), + ValueType::I64 => Self::I64(reader.read_i64::()?), + ValueType::F32 => Self::F32(reader.read_f32::()?), + ValueType::F64 => Self::F64(reader.read_f64::()?), + ValueType::Bool => match reader.read_u8()? 
+                0 => Self::Bool(false),
+                1 => Self::Bool(true),
+                b => crate::bail!("unexpected bool value {b}"),
+            },
+            ValueType::String => Self::String(read_string(reader, magic)?),
+            ValueType::Array => {
+                let value_type = reader.read_u32::<LittleEndian>()?;
+                let value_type = ValueType::from_u32(value_type)?;
+                let len = match magic {
+                    VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
+                    VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                        reader.read_u64::<LittleEndian>()? as usize
+                    }
+                };
+                let mut vs = Vec::with_capacity(len);
+                for _ in 0..len {
+                    vs.push(Value::read(reader, value_type, magic)?)
+                }
+                Self::Array(vs)
+            }
+        };
+        Ok(v)
+    }
+
+    fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
+        match self {
+            &Self::U8(v) => w.write_u8(v)?,
+            &Self::I8(v) => w.write_i8(v)?,
+            &Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
+            &Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
+            &Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
+            &Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
+            &Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
+            &Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
+            &Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
+            &Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
+            &Self::Bool(v) => w.write_u8(u8::from(v))?,
+            Self::String(v) => write_string(w, v.as_str())?,
+            Self::Array(v) => {
+                // The `Value` type does not enforce that all the values in an Array have the same
+                // type.
+                let value_type = if v.is_empty() {
+                    // Doesn't matter, the array is empty.
+                    ValueType::U32
+                } else {
+                    let value_type: std::collections::HashSet<_> =
+                        v.iter().map(|elem| elem.value_type()).collect();
+                    if value_type.len() != 1 {
+                        crate::bail!("multiple value-types in the same array {value_type:?}")
+                    }
+                    value_type.into_iter().next().unwrap()
+                };
+                w.write_u32::<LittleEndian>(value_type.to_u32())?;
+                w.write_u64::<LittleEndian>(v.len() as u64)?;
+                for elem in v.iter() {
+                    elem.write(w)?
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+impl ValueType {
+    fn from_u32(v: u32) -> Result<Self> {
+        let v = match v {
+            0 => Self::U8,
+            1 => Self::I8,
+            2 => Self::U16,
+            3 => Self::I16,
+            4 => Self::U32,
+            5 => Self::I32,
+            6 => Self::F32,
+            7 => Self::Bool,
+            8 => Self::String,
+            9 => Self::Array,
+            10 => Self::U64,
+            11 => Self::I64,
+            12 => Self::F64,
+            v => crate::bail!("unrecognized value-type {v:#08x}"),
+        };
+        Ok(v)
+    }
+
+    fn to_u32(self) -> u32 {
+        match self {
+            Self::U8 => 0,
+            Self::I8 => 1,
+            Self::U16 => 2,
+            Self::I16 => 3,
+            Self::U32 => 4,
+            Self::I32 => 5,
+            Self::F32 => 6,
+            Self::Bool => 7,
+            Self::String => 8,
+            Self::Array => 9,
+            Self::U64 => 10,
+            Self::I64 => 11,
+            Self::F64 => 12,
+        }
+    }
+}
+
+impl Content {
+    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
+        let magic = VersionedMagic::read(reader)?;
+
+        let tensor_count = match magic {
+            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
+        };
+        let metadata_kv_count = match magic {
+            VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
+            VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                reader.read_u64::<LittleEndian>()? as usize
+            }
+        };
+
+        let mut metadata = HashMap::new();
+        for _idx in 0..metadata_kv_count {
+            let key = read_string(reader, &magic)?;
+            let value_type = reader.read_u32::<LittleEndian>()?;
+            let value_type = ValueType::from_u32(value_type)?;
+            let value = Value::read(reader, value_type, &magic)?;
+            metadata.insert(key, value);
+        }
+        let mut tensor_infos = HashMap::new();
+        for _idx in 0..tensor_count {
+            let tensor_name = read_string(reader, &magic)?;
+            let n_dimensions = reader.read_u32::<LittleEndian>()?;
+
+            let mut dimensions: Vec<usize> = match magic {
+                VersionedMagic::GgufV1 => {
+                    let mut dimensions = vec![0; n_dimensions as usize];
+                    reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
+                    dimensions.into_iter().map(|c| c as usize).collect()
+                }
+                VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
+                    let mut dimensions = vec![0; n_dimensions as usize];
+                    reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
+                    dimensions.into_iter().map(|c| c as usize).collect()
+                }
+            };
+
+            dimensions.reverse();
+            let ggml_dtype = reader.read_u32::<LittleEndian>()?;
+            let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
+            let offset = reader.read_u64::<LittleEndian>()?;
+            tensor_infos.insert(
+                tensor_name,
+                TensorInfo {
+                    shape: crate::Shape::from(dimensions),
+                    offset,
+                    ggml_dtype,
+                },
+            );
+        }
+        let position = reader.stream_position()?;
+        let alignment = match metadata.get("general.alignment") {
+            Some(Value::U8(v)) => *v as u64,
+            Some(Value::U16(v)) => *v as u64,
+            Some(Value::U32(v)) => *v as u64,
+            Some(Value::I8(v)) if *v >= 0 => *v as u64,
+            Some(Value::I16(v)) if *v >= 0 => *v as u64,
+            Some(Value::I32(v)) if *v >= 0 => *v as u64,
+            _ => DEFAULT_ALIGNMENT,
+        };
+        let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
+        Ok(Self {
+            magic,
+            metadata,
+            tensor_infos,
+            tensor_data_offset,
+        })
+    }
+
+    pub fn tensor<R: std::io::Seek + std::io::Read>(
+        &self,
+        reader: &mut R,
+        name: &str,
+        device: &Device,
+    ) -> Result<QTensor> {
+        let tensor_info = match self.tensor_infos.get(name) {
+            Some(tensor_info) => tensor_info,
+            None => crate::bail!("cannot find tensor info for {name}"),
+        };
+        tensor_info.read(reader, self.tensor_data_offset, device)
+    }
+}
+
+fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
+    let bytes = str.as_bytes();
+    w.write_u64::<LittleEndian>(bytes.len() as u64)?;
+    w.write_all(bytes)?;
+    Ok(())
+}
+
+pub fn write<W: std::io::Seek + std::io::Write>(
+    w: &mut W,
+    metadata: &[(&str, &Value)],
+    tensors: &[(&str, &QTensor)],
+) -> Result<()> {
+    w.write_u32::<LittleEndian>(0x46554747)?;
+    w.write_u32::<LittleEndian>(2)?; // version 2.
+    w.write_u64::<LittleEndian>(tensors.len() as u64)?;
+    w.write_u64::<LittleEndian>(metadata.len() as u64)?;
+    for (name, value) in metadata.iter() {
+        write_string(w, name)?;
+        w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
+        value.write(w)?;
+    }
+    let mut offset = 0usize;
+    let mut offsets = Vec::with_capacity(tensors.len());
+    for (name, tensor) in tensors.iter() {
+        write_string(w, name)?;
+        let dims = tensor.shape().dims();
+        w.write_u32::<LittleEndian>(dims.len() as u32)?;
+        for &dim in dims.iter().rev() {
+            w.write_u64::<LittleEndian>(dim as u64)?;
+        }
+        w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
+        w.write_u64::<LittleEndian>(offset as u64)?;
+        offsets.push(offset);
+        let size_in_bytes = tensor.storage_size_in_bytes();
+        let padding = 31 - (31 + size_in_bytes) % 32;
+        offset += size_in_bytes + padding;
+    }
+    let pos = w.stream_position()? as usize;
+    let padding = 31 - (31 + pos) % 32;
+    w.write_all(&vec![0u8; padding])?;
+    let tensor_start_pos = w.stream_position()? as usize;
+    for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
+        let pos = w.stream_position()?
as usize; + if tensor_start_pos + offset != pos { + crate::bail!( + "internal error, unexpected current position {tensor_start_pos} {offset} {pos}" + ) + } + let data = tensor.data()?; + let size_in_bytes = data.len(); + w.write_all(&data)?; + let padding = 31 - (31 + size_in_bytes) % 32; + w.write_all(&vec![0u8; padding])?; + } + Ok(()) +} diff --git a/mcandle/mcandle-core/src/quantized/k_quants.rs b/mcandle/mcandle-core/src/quantized/k_quants.rs new file mode 100644 index 000000000..2e9292195 --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/k_quants.rs @@ -0,0 +1,2009 @@ +use super::utils::{ + get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants, + make_qkx1_quants, make_qx_quants, nearest_int, +}; +use super::GgmlDType; +use crate::Result; +use byteorder::{ByteOrder, LittleEndian}; +use half::{bf16, f16}; +use rayon::prelude::*; + +// Default to QK_K 256 rather than 64. +pub const QK_K: usize = 256; +pub const K_SCALE_SIZE: usize = 12; + +pub const QK4_0: usize = 32; +pub const QK4_1: usize = 32; +pub const QK5_0: usize = 32; +pub const QK5_1: usize = 32; +pub const QK8_0: usize = 32; +pub const QK8_1: usize = 32; + +pub trait GgmlType: Sized + Clone + Send + Sync { + const DTYPE: GgmlDType; + const BLCK_SIZE: usize; + type VecDotType: GgmlType; + + // This is only safe for types that include immediate values such as float/int/... + fn zeros() -> Self { + unsafe { std::mem::MaybeUninit::zeroed().assume_init() } + } + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>; + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>; + + /// Dot product used as a building block for quantized mat-mul. + /// n is the number of elements to be considered. + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; + + /// Generic implementation of the dot product without simd optimizations. 
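+    /// `vec_dot` falls back to this when no SIMD-accelerated path (avx / neon / simd128)
+    /// is available for the block type.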
+ fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result; +} + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ4_0 { + pub(crate) d: f16, + pub(crate) qs: [u8; QK4_0 / 2], +} +const _: () = assert!(std::mem::size_of::() == 18); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ4_1 { + pub(crate) d: f16, + pub(crate) m: f16, + pub(crate) qs: [u8; QK4_1 / 2], +} +const _: () = assert!(std::mem::size_of::() == 20); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ5_0 { + pub(crate) d: f16, + pub(crate) qh: [u8; 4], + pub(crate) qs: [u8; QK5_0 / 2], +} +const _: () = assert!(std::mem::size_of::() == 22); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ5_1 { + pub(crate) d: f16, + pub(crate) m: f16, + pub(crate) qh: [u8; 4], + pub(crate) qs: [u8; QK5_1 / 2], +} +const _: () = assert!(std::mem::size_of::() == 24); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ8_0 { + pub(crate) d: f16, + pub(crate) qs: [i8; QK8_0], +} +const _: () = assert!(std::mem::size_of::() == 34); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ8_1 { + pub(crate) d: f16, + pub(crate) s: f16, + pub(crate) qs: [i8; QK8_1], +} +const _: () = assert!(std::mem::size_of::() == 36); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ2K { + pub(crate) scales: [u8; QK_K / 16], + pub(crate) qs: [u8; QK_K / 4], + pub(crate) d: f16, + pub(crate) dmin: f16, +} +const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ3K { + pub(crate) hmask: [u8; QK_K / 8], + pub(crate) qs: [u8; QK_K / 4], + pub(crate) scales: [u8; 12], + pub(crate) d: f16, +} +const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82 +#[repr(C)] +pub struct BlockQ4K { + pub(crate) d: f16, + pub(crate) dmin: f16, + pub(crate) scales: [u8; K_SCALE_SIZE], + pub(crate) qs: [u8; QK_K / 2], +} +const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ5K { + pub(crate) d: f16, + pub(crate) dmin: f16, + pub(crate) scales: [u8; K_SCALE_SIZE], + pub(crate) qh: [u8; QK_K / 8], + pub(crate) qs: [u8; QK_K / 2], +} +const _: () = + assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ6K { + pub(crate) ql: [u8; QK_K / 2], + pub(crate) qh: [u8; QK_K / 4], + pub(crate) scales: [i8; QK_K / 16], + pub(crate) d: f16, +} +const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ8K { + pub(crate) d: f32, + pub(crate) qs: [i8; QK_K], + pub(crate) bsums: [i16; QK_K / 16], +} +const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::()); + +impl GgmlType for BlockQ4_0 { + const DTYPE: GgmlDType = GgmlDType::Q4_0; + const BLCK_SIZE: usize = QK4_0; + type VecDotType = BlockQ8_0; + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + let qk = Self::BLCK_SIZE; + if k % qk != 0 { + crate::bail!("dequantize_row_q4_0: {k} is not divisible by {qk}") + } + + let nb = k / qk; + for i in 0..nb { + 
let d = xs[i].d.to_f32(); + + for j in 0..(qk / 2) { + let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8; + let x1 = (xs[i].qs[j] >> 4) as i16 - 8; + + ys[i * qk + j] = (x0 as f32) * d; + ys[i * qk + j + qk / 2] = (x1 as f32) * d; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q4_0 + let qk = Self::BLCK_SIZE; + let k = xs.len(); + if k % qk != 0 { + crate::bail!("{k} is not divisible by {}", qk); + }; + let nb = k / qk; + if ys.len() != nb { + crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,) + } + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let mut max = 0f32; + + let xs = &xs[i * qk..(i + 1) * qk]; + for &x in xs.iter() { + if amax < x.abs() { + amax = x.abs(); + max = x; + } + } + let d = max / -8.0; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + + for (j, q) in ys.qs.iter_mut().enumerate() { + let x0 = xs[j] * id; + let x1 = xs[qk / 2 + j] * id; + let xi0 = u8::min(15, (x0 + 8.5) as u8); + let xi1 = u8::min(15, (x1 + 8.5) as u8); + *q = xi0 | (xi1 << 4) + } + } + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122 + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q4_0_q8_0(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q4_0_q8_0(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + // Generic implementation. + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let mut sum_i = 0; + for j in 0..qk / 2 { + let v0 = (xs.qs[j] & 0x0F) as i32 - 8; + let v1 = (xs.qs[j] >> 4) as i32 - 8; + sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32 + } + sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + } + Ok(sumf) + } +} + +impl GgmlType for BlockQ4_1 { + const DTYPE: GgmlDType = GgmlDType::Q4_1; + const BLCK_SIZE: usize = QK4_1; + type VecDotType = BlockQ8_1; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + // ggml_vec_dot_q4_1_q8_1 + let qk = QK8_1; + if n % qk != 0 { + crate::bail!("vec_dot_q4_1_q8_1: {n} is not divisible by {qk}") + } + let nb = n / qk; + if nb % 2 != 0 { + crate::bail!("vec_dot_q4_1_q8_1: {n}, nb is not divisible by 2") + } + + // Generic implementation. 
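+        // Per block: d_x * d_y * sum(qx_i * qy_i) + m_x * s_y, where s_y is the
+        // precomputed d_y * sum(qy_i) stored in the Q8_1 block.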
+ let mut sumf = 0f32; + + for (xs, ys) in xs.iter().zip(ys.iter()) { + let mut sumi = 0i32; + + for j in 0..qk / 2 { + let v0 = xs.qs[j] as i32 & 0x0F; + let v1 = xs.qs[j] as i32 >> 4; + sumi += (v0 * ys.qs[j] as i32) + (v1 * ys.qs[j + qk / 2] as i32); + } + + sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + + f16::to_f32(xs.m) * f16::to_f32(ys.s) + } + Ok(sumf) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q4_1 + let qk = Self::BLCK_SIZE; + if ys.len() * qk != xs.len() { + crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,) + } + for (i, ys) in ys.iter_mut().enumerate() { + let xs = &xs[i * qk..(i + 1) * qk]; + + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; + for &x in xs.iter() { + min = f32::min(x, min); + max = f32::max(x, max); + } + let d = (max - min) / ((1 << 4) - 1) as f32; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + ys.m = f16::from_f32(min); + + for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() { + let x0 = (xs[j] - min) * id; + let x1 = (xs[qk / 2 + j] - min) * id; + + let xi0 = u8::min(15, (x0 + 0.5) as u8); + let xi1 = u8::min(15, (x1 + 0.5) as u8); + + *q = xi0 | (xi1 << 4); + } + } + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK4_1 != 0 { + crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}"); + } + + let nb = k / QK4_1; + for i in 0..nb { + let d = xs[i].d.to_f32(); + let m = xs[i].m.to_f32(); + + for j in 0..(QK4_1 / 2) { + let x0 = xs[i].qs[j] & 0x0F; + let x1 = xs[i].qs[j] >> 4; + + ys[i * QK4_1 + j] = (x0 as f32) * d + m; + ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m; + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ5_0 { + const DTYPE: GgmlDType = GgmlDType::Q5_0; + const BLCK_SIZE: usize = QK5_0; + type VecDotType = BlockQ8_0; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = Self::BLCK_SIZE; + if n % Self::BLCK_SIZE != 0 { + crate::bail!("vec_dot_q5_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / qk; + if nb % 2 != 0 { + crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2") + } + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + // Generic implementation. + let mut sumf = 0f32; + + for (xs, ys) in xs.iter().zip(ys.iter()) { + let qh = LittleEndian::read_u32(&xs.qh); + let mut sumi = 0i32; + + for j in 0..Self::BLCK_SIZE / 2 { + let xh_0 = (((qh & (1u32 << j)) >> j) << 4) as u8; + let xh_1 = ((qh & (1u32 << (j + 16))) >> (j + 12)) as u8; + + let x0 = ((xs.qs[j] & 0x0F) as i32 | xh_0 as i32) - 16; + let x1 = ((xs.qs[j] >> 4) as i32 | xh_1 as i32) - 16; + + sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32); + } + + sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + } + Ok(sumf) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q5_0 + let k = xs.len(); + if ys.len() * Self::BLCK_SIZE != k { + crate::bail!("size mismatch {k} {} {}", ys.len(), Self::BLCK_SIZE) + } + for (i, ys) in ys.iter_mut().enumerate() { + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + + let mut amax = 0f32; + let mut max = 0f32; + for &x in xs.iter() { + if amax < x.abs() { + amax = x.abs(); + max = x; + } + } + let d = max / -16.; + let id = if d != 0f32 { 1. / d } else { 0. 
}; + ys.d = f16::from_f32(d); + let mut qh = 0u32; + for j in 0..Self::BLCK_SIZE / 2 { + let x0 = xs[j] * id; + let x1 = xs[j + Self::BLCK_SIZE / 2] * id; + let xi0 = ((x0 + 16.5) as i8).min(31) as u8; + let xi1 = ((x1 + 16.5) as i8).min(31) as u8; + ys.qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + qh |= ((xi0 as u32 & 0x10) >> 4) << j; + qh |= ((xi1 as u32 & 0x10) >> 4) << (j + Self::BLCK_SIZE / 2); + } + LittleEndian::write_u32(&mut ys.qh, qh) + } + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK5_0 != 0 { + crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}"); + } + + let nb = k / QK5_0; + for i in 0..nb { + let d = xs[i].d.to_f32(); + let qh: u32 = LittleEndian::read_u32(&xs[i].qh); + + for j in 0..(QK5_0 / 2) { + let xh_0 = (((qh >> j) << 4) & 0x10) as u8; + let xh_1 = ((qh >> (j + 12)) & 0x10) as u8; + + let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16; + let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16; + + ys[i * QK5_0 + j] = (x0 as f32) * d; + ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d; + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ5_1 { + const DTYPE: GgmlDType = GgmlDType::Q5_1; + const BLCK_SIZE: usize = QK5_1; + type VecDotType = BlockQ8_1; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = Self::BLCK_SIZE; + if n % Self::BLCK_SIZE != 0 { + crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}") + } + let nb = n / qk; + if nb % 2 != 0 { + crate::bail!("vec_dot_q5_1_q8_1: {n}, nb is not divisible by 2") + } + + // Generic implementation. + let mut sumf = 0f32; + + for (xs, ys) in xs.iter().zip(ys.iter()) { + let qh = LittleEndian::read_u32(&xs.qh); + let mut sumi = 0i32; + + for j in 0..Self::BLCK_SIZE / 2 { + let xh_0 = ((qh >> j) << 4) & 0x10; + let xh_1 = (qh >> (j + 12)) & 0x10; + + let x0 = (xs.qs[j] as i32 & 0xF) | xh_0 as i32; + let x1 = (xs.qs[j] as i32 >> 4) | xh_1 as i32; + + sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32); + } + + sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + + f16::to_f32(xs.m) * f16::to_f32(ys.s) + } + Ok(sumf) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q5_1 + let qk = Self::BLCK_SIZE; + if ys.len() * qk != xs.len() { + crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,) + } + for (i, ys) in ys.iter_mut().enumerate() { + let xs = &xs[i * qk..(i + 1) * qk]; + + let mut min = f32::INFINITY; + let mut max = f32::NEG_INFINITY; + for &x in xs.iter() { + min = f32::min(x, min); + max = f32::max(x, max); + } + let d = (max - min) / ((1 << 5) - 1) as f32; + let id = if d != 0f32 { 1. / d } else { 0. 
}; + ys.d = f16::from_f32(d); + ys.m = f16::from_f32(min); + + let mut qh = 0u32; + for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() { + let x0 = (xs[j] - min) * id; + let x1 = (xs[qk / 2 + j] - min) * id; + + let xi0 = (x0 + 0.5) as u8; + let xi1 = (x1 + 0.5) as u8; + + *q = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 as u32 & 0x10) >> 4) << j; + qh |= ((xi1 as u32 & 0x10) >> 4) << (j + qk / 2); + } + LittleEndian::write_u32(&mut ys.qh, qh); + } + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK5_1 != 0 { + crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}"); + } + + let nb = k / QK5_1; + for i in 0..nb { + let d = xs[i].d.to_f32(); + let m = xs[i].m.to_f32(); + let qh: u32 = LittleEndian::read_u32(&xs[i].qh); + + for j in 0..(QK5_1 / 2) { + let xh_0 = (((qh >> j) << 4) & 0x10) as u8; + let xh_1 = ((qh >> (j + 12)) & 0x10) as u8; + + let x0 = (xs[i].qs[j] & 0x0F) | xh_0; + let x1 = (xs[i].qs[j] >> 4) | xh_1; + + ys[i * QK5_1 + j] = (x0 as f32) * d + m; + ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m; + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ8_0 { + const DTYPE: GgmlDType = GgmlDType::Q8_0; + const BLCK_SIZE: usize = QK8_0; + type VecDotType = BlockQ8_0; + + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK8_0 != 0 { + crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}"); + } + + let nb = k / QK8_0; + + for i in 0..nb { + let d = xs[i].d.to_f32(); + + for j in 0..QK8_0 { + ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q8_0 + let k = xs.len(); + if k % Self::BLCK_SIZE != 0 { + crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); + }; + let nb = k / Self::BLCK_SIZE; + if ys.len() != nb { + crate::bail!( + "size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) + } + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + for &x in xs.iter() { + amax = amax.max(x.abs()) + } + let d = amax / ((1 << 7) - 1) as f32; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { + *y = f32::round(x * id) as i8 + } + } + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q8_0_q8_0(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q8_0_q8_0(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + + // Generic implementation. 
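+        // Per block: d_x * d_y * sum(qx_i * qy_i), with both operands quantized to 8 bits.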
+ let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d) + } + Ok(sumf) + } +} + +impl GgmlType for BlockQ8_1 { + const DTYPE: GgmlDType = GgmlDType::Q8_1; + const BLCK_SIZE: usize = QK8_1; + type VecDotType = BlockQ8_1; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result { + unimplemented!("no support for vec-dot on Q8_1") + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q8_1 + let k = xs.len(); + if ys.len() * Self::BLCK_SIZE != k { + crate::bail!("size mismatch {k} {} {}", ys.len(), Self::BLCK_SIZE) + } + for (i, ys) in ys.iter_mut().enumerate() { + let mut amax = 0f32; + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + for &x in xs.iter() { + amax = amax.max(x.abs()) + } + let d = amax / ((1 << 7) - 1) as f32; + let id = if d != 0f32 { 1. / d } else { 0. }; + ys.d = f16::from_f32(d); + let mut sum = 0i32; + for j in 0..Self::BLCK_SIZE / 2 { + let v0 = xs[j] * id; + let v1 = xs[j + Self::BLCK_SIZE / 2] * id; + ys.qs[j] = f32::round(v0) as i8; + ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8; + sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32; + } + ys.s = f16::from_f32(sum as f32) * ys.d; + } + Ok(()) + } + + fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> { + unimplemented!("no support for vec-dot on Q8_1") + } +} + +impl GgmlType for BlockQ2K { + const DTYPE: GgmlDType = GgmlDType::Q2K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q2k_q8k(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2k_q8k(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q2k_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + + let mut sumf = 0.0; + for (x, y) in xs.iter().zip(ys.iter()) { + let mut q2: &[_] = &x.qs; + let mut q8: &[_] = &y.qs; + let sc = &x.scales; + + let mut summs = 0; + for (bsum, scale) in y.bsums.iter().zip(sc) { + summs += *bsum as i32 * ((scale >> 4) as i32); + } + + let dall = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let mut isum = 0; + let mut is = 0; + for _ in 0..(QK_K / 128) { + let mut shift = 0; + for _ in 0..4 { + let d = (sc[is] & 0xF) as i32; + is += 1; + let mut isuml = 0; + for l in 0..16 { + isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32); + } + isum += d * isuml; + let d = (sc[is] & 0xF) as i32; + is += 1; + isuml = 0; + for l in 16..32 { + isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32); + } + isum += d * isuml; + shift += 2; + // adjust the indexing + q8 = &q8[32..]; + } + // adjust the indexing + q2 = &q2[32..]; + } + sumf += dall * isum as f32 - dmin * summs as f32; + } + + Ok(sumf) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L279 + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + const Q4SCALE: f32 = 15.0; + + for (block, x) in 
group_for_quantization(xs, ys)? { + //calculate scales and mins + let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16]; + let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16]; + + for (j, x_scale_slice) in x.chunks(16).enumerate() { + (scales[j], mins[j]) = make_qkx1_quants(3, 5, x_scale_slice); + } + // get max scale and max min and ensure they are >= 0.0 + let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max)); + let max_min = mins.iter().fold(0.0, |max, &val| val.max(max)); + + if max_scale > 0.0 { + let iscale = Q4SCALE / max_scale; + for (j, scale) in scales.iter().enumerate().take(QK_K / 16) { + block.scales[j] = nearest_int(iscale * scale) as u8; + } + block.d = f16::from_f32(max_scale / Q4SCALE); + } else { + for j in 0..QK_K / 16 { + block.scales[j] = 0; + } + block.d = f16::from_f32(0.0); + } + + if max_min > 0.0 { + let iscale = Q4SCALE / max_min; + for (j, scale) in block.scales.iter_mut().enumerate() { + let l = nearest_int(iscale * mins[j]) as u8; + *scale |= l << 4; + } + block.dmin = f16::from_f32(max_min / Q4SCALE); + } else { + block.dmin = f16::from_f32(0.0); + } + + let mut big_l: [u8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 16 { + let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32; + if d == 0.0 { + continue; + } + let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32; + for ii in 0..16 { + let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3); + big_l[16 * j + ii] = ll as u8; + } + } + + for j in (0..QK_K).step_by(128) { + for ll in 0..32 { + block.qs[j / 4 + ll] = big_l[j + ll] + | (big_l[j + ll + 32] << 2) + | (big_l[j + ll + 64] << 4) + | (big_l[j + ll + 96] << 6); + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + for (block, y) in group_for_dequantization(xs, ys)? { + let d = block.d.to_f32(); + let min = block.dmin.to_f32(); + + let mut is = 0; + + for (y_block, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) { + // Step by 32 over q. + let mut shift = 0; + let mut y_block_index = 0; + for _j in 0..4 { + let sc = block.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &qs[..16] { + let y = dl * ((q >> shift) & 3) as f32 - ml; + y_block[y_block_index] = y; + y_block_index += 1; + } + + let sc = block.scales[is]; + is += 1; + let dl = d * (sc & 0xF) as f32; + let ml = min * (sc >> 4) as f32; + for q in &qs[16..] 
{ + let y = dl * ((q >> shift) & 3) as f32 - ml; + y_block[y_block_index] = y; + y_block_index += 1; + } + + shift += 2; + } + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ3K { + const DTYPE: GgmlDType = GgmlDType::Q3K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q3k_q8k(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q3k_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + let mut aux8: [i8; QK_K] = [0; QK_K]; + let mut aux16: [i16; 8] = [0; 8]; + let mut sums: [f32; 8] = [0.0; 8]; + let mut aux32: [i32; 8] = [0; 8]; + + let mut auxs: [u32; 4] = [0; 4]; + + for (x, y) in xs.iter().zip(ys.iter()) { + let mut q3: &[u8] = &x.qs; + let hmask: &[u8] = &x.hmask; + let mut q8: &[i8] = &y.qs; + + aux32.fill(0); + let mut a = &mut aux8[..]; + + let mut m = 1; + //Like the GGML original this is written this way to enable the compiler to vectorize it. + for _ in 0..QK_K / 128 { + a.iter_mut() + .take(32) + .zip(q3) + .for_each(|(a_val, q3_val)| *a_val = (q3_val & 3) as i8); + a.iter_mut() + .take(32) + .zip(hmask) + .for_each(|(a_val, hmask_val)| { + *a_val -= if hmask_val & m != 0 { 0 } else { 4 } + }); + a = &mut a[32..]; + m <<= 1; + + a.iter_mut() + .take(32) + .zip(q3) + .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 2) & 3) as i8); + a.iter_mut() + .take(32) + .zip(hmask) + .for_each(|(a_val, hmask_val)| { + *a_val -= if hmask_val & m != 0 { 0 } else { 4 } + }); + a = &mut a[32..]; + m <<= 1; + + a.iter_mut() + .take(32) + .zip(q3) + .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 4) & 3) as i8); + a.iter_mut() + .take(32) + .zip(hmask) + .for_each(|(a_val, hmask_val)| { + *a_val -= if hmask_val & m != 0 { 0 } else { 4 } + }); + a = &mut a[32..]; + m <<= 1; + + a.iter_mut() + .take(32) + .zip(q3) + .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 6) & 3) as i8); + a.iter_mut() + .take(32) + .zip(hmask) + .for_each(|(a_val, hmask_val)| { + *a_val -= if hmask_val & m != 0 { 0 } else { 4 } + }); + a = &mut a[32..]; + m <<= 1; + q3 = &q3[32..]; + } + + a = &mut aux8[..]; + + LittleEndian::read_u32_into(&x.scales, &mut auxs[0..3]); + + let tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4); + auxs[3] = ((auxs[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4); + auxs[0] = (auxs[0] & KMASK2) | (((tmp) & KMASK1) << 4); + auxs[1] = (auxs[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4); + + for aux in auxs { + for scale in aux.to_le_bytes() { + let scale = i8::from_be_bytes([scale]); + for l in 0..8 { + aux16[l] = q8[l] as i16 * a[l] as i16; + } + for l in 0..8 { + aux32[l] += (scale as i32 - 32) * aux16[l] as i32; + } + q8 = &q8[8..]; + a = &mut a[8..]; + + for l in 0..8 { + aux16[l] = q8[l] as i16 * a[l] as i16; + } + for l in 0..8 { + aux32[l] += (scale as i32 - 32) * aux16[l] as i32; + } + q8 = &q8[8..]; + a = &mut a[8..]; + } + } + let d = x.d.to_f32() * y.d; + for l in 0..8 { + sums[l] += d * aux32[l] as f32; + } + } + + Ok(sums.iter().sum()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + for (block, x) in group_for_quantization(xs, ys)? 
{ + let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16]; + for (j, x_scale_slice) in x.chunks_exact(16).enumerate() { + scales[j] = make_q3_quants(x_scale_slice, 4, true); + } + + // Get max scale by absolute value. + let mut max_scale: f32 = 0.0; + for &scale in scales.iter() { + if scale.abs() > max_scale.abs() { + max_scale = scale; + } + } + + block.scales.fill(0); + + if max_scale != 0.0 { + let iscale = -32.0 / max_scale; + for (j, scale) in scales.iter().enumerate() { + let l_val = nearest_int(iscale * scale); + let l_val = l_val.clamp(-32, 31) + 32; + if j < 8 { + block.scales[j] = (l_val & 0xF) as u8; + } else { + block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8; + } + let l_val = l_val >> 4; + block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8; + } + block.d = f16::from_f32(1.0 / iscale); + } else { + block.d = f16::from_f32(0.0); + } + + let mut l: [i8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 16 { + let sc = if j < 8 { + block.scales[j] & 0xF + } else { + block.scales[j - 8] >> 4 + }; + let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32; + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + for ii in 0..16 { + let l_val = nearest_int(x[16 * j + ii] / d); + l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8; + } + } + } + + block.hmask.fill(0); + let mut m = 0; + let mut hm = 1; + + for ll in l.iter_mut() { + if *ll > 3 { + block.hmask[m] |= hm; + *ll -= 4; + } + m += 1; + if m == QK_K / 8 { + m = 0; + hm <<= 1; + } + } + + for j in (0..QK_K).step_by(128) { + for l_val in 0..32 { + block.qs[j / 4 + l_val] = (l[j + l_val] + | (l[j + l_val + 32] << 2) + | (l[j + l_val + 64] << 4) + | (l[j + l_val + 96] << 6)) + as u8; + } + } + } + + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + for (block, y) in group_for_dequantization(xs, ys)? 
{ + //Reconstruct the scales + let mut aux = [0; 4]; + LittleEndian::read_u32_into(&block.scales, &mut aux[0..3]); + + let tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4); + aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4); + aux[0] = (aux[0] & KMASK2) | (((tmp) & KMASK1) << 4); + aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4); + + //Transfer the scales into an i8 array + let scales: &mut [i8] = + unsafe { std::slice::from_raw_parts_mut(aux.as_mut_ptr() as *mut i8, 16) }; + + let d_all = block.d.to_f32(); + let mut m = 1; + let mut is = 0; + + // Dequantize both 128 long blocks + // 32 qs values per 128 long block + // Each 16 elements get a scale + for (y, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) { + let mut shift = 0; + for shift_scoped_y in y.chunks_exact_mut(32) { + for (scale_index, scale_scoped_y) in + shift_scoped_y.chunks_exact_mut(16).enumerate() + { + let dl = d_all * (scales[is] as f32 - 32.0); + for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() { + let new_y = dl + * (((qs[i + 16 * scale_index] >> shift) & 3) as i8 + - if (block.hmask[i + 16 * scale_index] & m) == 0 { + 4 + } else { + 0 + }) as f32; + *inner_y = new_y; + } + // 16 block finished => advance scale index + is += 1; + } + // 32 block finished => increase shift and m + shift += 2; + m <<= 1; + } + } + } + + Ok(()) + } +} + +impl GgmlType for BlockQ4K { + const DTYPE: GgmlDType = GgmlDType::Q4K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q4k_q8k(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q4k_q8k(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q4k_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + let mut utmp: [u32; 4] = [0; 4]; + let mut scales: [u8; 8] = [0; 8]; + let mut mins: [u8; 8] = [0; 8]; + + let mut aux8: [i8; QK_K] = [0; QK_K]; + let mut aux16: [i16; 8] = [0; 8]; + let mut sums: [f32; 8] = [0.0; 8]; + let mut aux32: [i32; 8] = [0; 8]; + + let mut sumf = 0.0; + for (y, x) in ys.iter().zip(xs.iter()) { + let q4 = &x.qs; + let q8 = &y.qs; + aux32.fill(0); + + let mut a = &mut aux8[..]; + let mut q4 = &q4[..]; + for _ in 0..QK_K / 64 { + for l in 0..32 { + a[l] = (q4[l] & 0xF) as i8; + } + a = &mut a[32..]; + for l in 0..32 { + a[l] = (q4[l] >> 4) as i8; + } + a = &mut a[32..]; + q4 = &q4[32..]; + } + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + //extract scales and mins + LittleEndian::write_u32_into(&utmp[0..2], &mut scales); + LittleEndian::write_u32_into(&utmp[2..4], &mut mins); + + let mut sumi = 0; + for j in 0..QK_K / 16 { + sumi += y.bsums[j] as i32 * mins[j / 2] as i32; + } + + let mut a = &mut aux8[..]; + let mut q8 = &q8[..]; + + for scale in scales { + let scale = scale as i32; + for _ in 0..4 { + for l in 0..8 { + aux16[l] = q8[l] as i16 * a[l] as i16; 
+ } + for l in 0..8 { + aux32[l] += scale * aux16[l] as i32; + } + q8 = &q8[8..]; + a = &mut a[8..]; + } + } + let d = x.d.to_f32() * y.d; + for l in 0..8 { + sums[l] += d * aux32[l] as f32; + } + let dmin = x.dmin.to_f32() * y.d; + sumf -= dmin * sumi as f32; + } + Ok(sumf + sums.iter().sum::()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + for (block, x) in group_for_quantization(xs, ys)? { + let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32]; + + for (j, x_scale_slice) in x.chunks_exact(32).enumerate() { + (scales[j], mins[j]) = make_qkx1_quants(15, 5, x_scale_slice); + } + + // get max scale and max min and ensure they are >= 0.0 + let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max)); + let max_min = mins.iter().fold(0.0, |max, &val| val.max(max)); + + let inv_scale = if max_scale > 0.0 { + 63.0 / max_scale + } else { + 0.0 + }; + let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 }; + + for j in 0..QK_K / 32 { + let ls = nearest_int(inv_scale * scales[j]).min(63) as u8; + let lm = nearest_int(inv_min * mins[j]).min(63) as u8; + if j < 4 { + block.scales[j] = ls; + block.scales[j + 4] = lm; + } else { + block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4); + block.scales[j - 4] |= (ls >> 4) << 6; + block.scales[j] |= (lm >> 4) << 6; + } + } + + block.d = f16::from_f32(max_scale / 63.0); + block.dmin = f16::from_f32(max_min / 63.0); + + let mut l: [u8; QK_K] = [0; QK_K]; + + for j in 0..QK_K / 32 { + let (sc, m) = get_scale_min_k4(j, &block.scales); + let d = block.d.to_f32() * sc as f32; + if d != 0.0 { + let dm = block.dmin.to_f32() * m as f32; + for ii in 0..32 { + let l_val = nearest_int((x[32 * j + ii] + dm) / d); + l[32 * j + ii] = l_val.clamp(0, 15) as u8; + } + } + } + + let q = &mut block.qs; + for j in (0..QK_K).step_by(64) { + for l_val in 0..32 { + let offset_index = (j / 64) * 32 + l_val; + q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4); + } + } + } + Ok(()) + } + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + for (block, y) in group_for_dequantization(xs, ys)? 
{ + let d = block.d.to_f32(); + let min = block.dmin.to_f32(); + let q = &block.qs; + let mut is = 0; + let mut ys_index = 0; + + for j in (0..QK_K).step_by(64) { + let q = &q[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &block.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &block.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for q in q { + y[ys_index] = d1 * (q & 0xF) as f32 - m1; + ys_index += 1; + } + for q in q { + y[ys_index] = d2 * (q >> 4) as f32 - m2; + ys_index += 1; + } + is += 2; + } + } + Ok(()) + } +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 +impl GgmlType for BlockQ5K { + const DTYPE: GgmlDType = GgmlDType::Q5K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q5k_q8k(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q5k_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + let mut utmp: [u32; 4] = [0; 4]; + let mut scales: [u8; 8] = [0; 8]; + let mut mins: [u8; 8] = [0; 8]; + + let mut aux8: [i8; QK_K] = [0; QK_K]; + let mut aux16: [i16; 8] = [0; 8]; + let mut sums: [f32; 8] = [0.0; 8]; + let mut aux32: [i32; 8] = [0; 8]; + + let mut sumf = 0.0; + for (y, x) in ys.iter().zip(xs.iter()) { + let q5 = &x.qs; + let hm = &x.qh; + let q8 = &y.qs; + aux32.fill(0); + + let mut a = &mut aux8[..]; + let mut q5 = &q5[..]; + let mut m = 1u8; + + for _ in 0..QK_K / 64 { + for l in 0..32 { + a[l] = (q5[l] & 0xF) as i8; + a[l] += if hm[l] & m != 0 { 16 } else { 0 }; + } + a = &mut a[32..]; + m <<= 1; + for l in 0..32 { + a[l] = (q5[l] >> 4) as i8; + a[l] += if hm[l] & m != 0 { 16 } else { 0 }; + } + a = &mut a[32..]; + m <<= 1; + q5 = &q5[32..]; + } + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + //extract scales and mins + LittleEndian::write_u32_into(&utmp[0..2], &mut scales); + LittleEndian::write_u32_into(&utmp[2..4], &mut mins); + + let mut sumi = 0; + for j in 0..QK_K / 16 { + sumi += y.bsums[j] as i32 * mins[j / 2] as i32; + } + + let mut a = &mut aux8[..]; + let mut q8 = &q8[..]; + + for scale in scales { + let scale = scale as i32; + for _ in 0..4 { + for l in 0..8 { + aux16[l] = q8[l] as i16 * a[l] as i16; + } + for l in 0..8 { + aux32[l] += scale * aux16[l] as i32; + } + q8 = &q8[8..]; + a = &mut a[8..]; + } + } + let d = x.d.to_f32() * y.d; + for l in 0..8 { + sums[l] += d * aux32[l] as f32; + } + let dmin = x.dmin.to_f32() * y.d; + sumf -= dmin * sumi as f32; + } + Ok(sumf + sums.iter().sum::()) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L793 + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + for (block, x) in group_for_quantization(xs, ys)? 
{ + let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32]; + let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32]; + + for (j, x_scale_slice) in x.chunks_exact(32).enumerate() { + (scales[j], mins[j]) = make_qkx1_quants(31, 5, x_scale_slice); + } + + // get max scale and max min and ensure they are >= 0.0 + let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max)); + let max_min = mins.iter().fold(0.0, |max, &val| val.max(max)); + + let inv_scale = if max_scale > 0.0 { + 63.0 / max_scale + } else { + 0.0 + }; + let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 }; + for j in 0..QK_K / 32 { + let ls = nearest_int(inv_scale * scales[j]).min(63) as u8; + let lm = nearest_int(inv_min * mins[j]).min(63) as u8; + if j < 4 { + block.scales[j] = ls; + block.scales[j + 4] = lm; + } else { + block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4); + block.scales[j - 4] |= (ls >> 4) << 6; + block.scales[j] |= (lm >> 4) << 6; + } + } + block.d = f16::from_f32(max_scale / 63.0); + block.dmin = f16::from_f32(max_min / 63.0); + + let mut l: [u8; QK_K] = [0; QK_K]; + for j in 0..QK_K / 32 { + let (sc, m) = get_scale_min_k4(j, &block.scales); + let d = block.d.to_f32() * sc as f32; + if d == 0.0 { + continue; + } + let dm = block.dmin.to_f32() * m as f32; + for ii in 0..32 { + let ll = nearest_int((x[32 * j + ii] + dm) / d); + l[32 * j + ii] = ll.clamp(0, 31) as u8; + } + } + + let qh = &mut block.qh; + let ql = &mut block.qs; + qh.fill(0); + + let mut m1 = 1; + let mut m2 = 2; + for n in (0..QK_K).step_by(64) { + let offset = (n / 64) * 32; + for j in 0..32 { + let mut l1 = l[n + j]; + if l1 > 15 { + l1 -= 16; + qh[j] |= m1; + } + let mut l2 = l[n + j + 32]; + if l2 > 15 { + l2 -= 16; + qh[j] |= m2; + } + ql[offset + j] = l1 | (l2 << 4); + } + m1 <<= 2; + m2 <<= 2; + } + } + + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + for (block, y) in group_for_dequantization(xs, ys)? 
{ + let d = block.d.to_f32(); + let min = block.dmin.to_f32(); + let ql = &block.qs; + let qh = &block.qh; + let mut is = 0; + let mut u1 = 1; + let mut u2 = 2; + let mut ys_index = 0; + + for j in (0..QK_K).step_by(64) { + let ql = &ql[j / 2..j / 2 + 32]; + let (sc, m) = get_scale_min_k4(is, &block.scales); + let d1 = d * sc as f32; + let m1 = min * m as f32; + let (sc, m) = get_scale_min_k4(is + 1, &block.scales); + let d2 = d * sc as f32; + let m2 = min * m as f32; + for (ql, qh) in ql.iter().zip(qh) { + let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1; + ys_index += 1; + } + for (ql, qh) in ql.iter().zip(qh) { + let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2; + ys_index += 1; + } + is += 2; + u1 <<= 2; + u2 <<= 2; + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ6K { + const DTYPE: GgmlDType = GgmlDType::Q6K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q6k_q8k(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q6k_q8k(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q6k_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}") + } + + let mut aux8 = [0i8; QK_K]; + let mut aux16 = [0i16; 8]; + let mut sums = [0f32; 8]; + let mut aux32 = [0f32; 8]; + + for (x, y) in xs.iter().zip(ys.iter()) { + let q4 = &x.ql; + let qh = &x.qh; + let q8 = &y.qs; + aux32.fill(0f32); + + for j in (0..QK_K).step_by(128) { + let aux8 = &mut aux8[j..]; + let q4 = &q4[j / 2..]; + let qh = &qh[j / 4..]; + for l in 0..32 { + aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8; + aux8[l + 32] = + (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8; + aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8; + aux8[l + 96] = + (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8; + } + } + + for (j, &scale) in x.scales.iter().enumerate() { + let scale = scale as f32; + let q8 = &q8[16 * j..]; + let aux8 = &aux8[16 * j..]; + for l in 0..8 { + aux16[l] = q8[l] as i16 * aux8[l] as i16; + } + for l in 0..8 { + aux32[l] += scale * aux16[l] as f32 + } + let q8 = &q8[8..]; + let aux8 = &aux8[8..]; + for l in 0..8 { + aux16[l] = q8[l] as i16 * aux8[l] as i16; + } + for l in 0..8 { + aux32[l] += scale * aux16[l] as f32 + } + } + + let d = x.d.to_f32() * y.d; + for (sum, &a) in sums.iter_mut().zip(aux32.iter()) { + *sum += a * d; + } + } + Ok(sums.iter().sum()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() * Self::BLCK_SIZE { + crate::bail!( + "quantize_row_q6k: size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) + } + let mut l = [0i8; QK_K]; + let mut scales = [0f32; QK_K / 16]; + let mut x = xs.as_ptr(); + let l = l.as_mut_ptr(); + unsafe { + for y in ys.iter_mut() { + let mut max_scale = 0f32; + let mut max_abs_scale = 0f32; + for (ib, scale_) in scales.iter_mut().enumerate() { + let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1); + *scale_ = scale; + let abs_scale = scale.abs(); + if abs_scale > max_abs_scale { + max_abs_scale = abs_scale; + 
max_scale = scale + } + } + + let iscale = -128f32 / max_scale; + y.d = f16::from_f32(1.0 / iscale); + + for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) { + *y_scale = nearest_int(iscale * scale).min(127) as i8 + } + + for (j, &y_scale) in y.scales.iter().enumerate() { + let d = y.d.to_f32() * y_scale as f32; + if d == 0. { + continue; + } + for ii in 0..16 { + let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31); + *l.add(16 * j + ii) = (ll + 32) as i8 + } + } + + let mut ql = y.ql.as_mut_ptr(); + let mut qh = y.qh.as_mut_ptr(); + + for j in (0..QK_K).step_by(128) { + for l_idx in 0..32 { + let q1 = *l.add(j + l_idx) & 0xF; + let q2 = *l.add(j + l_idx + 32) & 0xF; + let q3 = *l.add(j + l_idx + 64) & 0xF; + let q4 = *l.add(j + l_idx + 96) & 0xF; + *ql.add(l_idx) = (q1 | (q3 << 4)) as u8; + *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8; + *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4) + | ((*l.add(j + l_idx + 32) >> 4) << 2) + | ((*l.add(j + l_idx + 64) >> 4) << 4) + | ((*l.add(j + l_idx + 96) >> 4) << 6)) + as u8; + } + ql = ql.add(64); + qh = qh.add(32); + } + + x = x.add(QK_K) + } + } + Ok(()) + } + + // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067 + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}") + } + for (idx_x, x) in xs.iter().enumerate() { + let d = x.d.to_f32(); + let ql = &x.ql; + let qh = &x.qh; + let sc = &x.scales; + for n in (0..QK_K).step_by(128) { + let idx = n / 128; + let ys = &mut ys[idx_x * QK_K + n..]; + let sc = &sc[8 * idx..]; + let ql = &ql[64 * idx..]; + let qh = &qh[32 * idx..]; + for l in 0..32 { + let is = l / 16; + let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32; + let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32; + let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32; + let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32; + ys[l] = d * sc[is] as f32 * q1 as f32; + ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32; + ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32; + ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32; + } + } + } + Ok(()) + } +} + +impl GgmlType for BlockQ8K { + const DTYPE: GgmlDType = GgmlDType::Q8K; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q8k_q8k(n, xs, ys); + + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q8k_q8k(n, xs, ys); + + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q8k_q8k(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") + } + + // Generic implementation. 
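+        // Same structure as the Q8_0 dot product, but the per-block scale d is stored as
+        // an f32 rather than an f16.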
+ let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32 * xs.d * ys.d + } + Ok(sumf) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + let k = xs.len(); + if k % QK_K != 0 { + crate::bail!("quantize_row_q8k: {k} is not divisible by {QK_K}") + } + for (i, y) in ys.iter_mut().enumerate() { + let mut max = 0f32; + let mut amax = 0f32; + let xs = &xs[i * QK_K..(i + 1) * QK_K]; + for &x in xs.iter() { + if amax < x.abs() { + amax = x.abs(); + max = x; + } + } + if amax == 0f32 { + y.d = 0f32; + y.qs.fill(0) + } else { + let iscale = -128f32 / max; + for (j, q) in y.qs.iter_mut().enumerate() { + // ggml uses nearest_int with bit magic here, maybe we want the same + // but we would have to test and benchmark it. + let v = (iscale * xs[j]).round(); + *q = v.min(127.) as i8 + } + for j in 0..QK_K / 16 { + let mut sum = 0i32; + for ii in 0..16 { + sum += y.qs[j * 16 + ii] as i32 + } + y.bsums[j] = sum as i16 + } + y.d = 1.0 / iscale + } + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QK_K != 0 { + crate::bail!("dequantize_row_q8k: {k} is not divisible by {QK_K}") + } + for (i, x) in xs.iter().enumerate() { + for (j, &q) in x.qs.iter().enumerate() { + ys[i * QK_K + j] = x.d * q as f32 + } + } + Ok(()) + } +} + +// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 +pub fn matmul( + mkn: (usize, usize, usize), + lhs: &[f32], + rhs_t: &[T], + dst: &mut [f32], +) -> Result<()> { + let (m, k, n) = mkn; + if m * k != lhs.len() { + crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len()); + } + + let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE; + let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE; + // TODO: Do not make this copy if the DotType is f32. + // TODO: Pre-allocate this. + let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks]; + for row_idx in 0..m { + let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let lhs = &lhs[row_idx * k..(row_idx + 1) * k]; + T::VecDotType::from_float(lhs, lhs_b)? 
+ } + let lhs_b = lhs_b.as_slice(); + + for row_idx in 0..m { + let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks]; + let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n]; + + let result: Result> = dst_row + .into_par_iter() + .enumerate() + .with_min_len(128) + .with_max_len(512) + .map(|(col_idx, dst)| { + let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks]; + T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value) + }) + .collect(); + + result?; + } + Ok(()) +} + +impl GgmlType for f32 { + const DTYPE: GgmlDType = GgmlDType::F32; + const BLCK_SIZE: usize = 1; + type VecDotType = f32; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + ys.copy_from_slice(xs); + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + ys.copy_from_slice(xs); + Ok(()) + } +} + +impl GgmlType for f16 { + const DTYPE: GgmlDType = GgmlDType::F16; + const BLCK_SIZE: usize = 1; + type VecDotType = f16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = f16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} + +impl GgmlType for bf16 { + const DTYPE: GgmlDType = GgmlDType::BF16; + const BLCK_SIZE: usize = 1; + type VecDotType = bf16; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if xs.len() < n { + crate::bail!("size mismatch {} < {n}", xs.len()) + } + if ys.len() < n { + crate::bail!("size mismatch {} < {n}", ys.len()) + } + let mut res = 0f32; + unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) }; + Ok(res) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() != ys.len() { + crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = bf16::from_f32(*x) + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if xs.len() != ys.len() { + 
crate::bail!("size mismatch {} {}", xs.len(), ys.len()); + } + // TODO: vectorize + for (x, y) in xs.iter().zip(ys.iter_mut()) { + *y = x.to_f32() + } + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/quantized/metal.rs b/mcandle/mcandle-core/src/quantized/metal.rs new file mode 100644 index 000000000..b3d5d6e21 --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/metal.rs @@ -0,0 +1,357 @@ +use super::{GgmlDType, QStorage}; +use crate::backend::BackendStorage; +use crate::{DType, MetalDevice, MetalStorage, Result, Shape, D}; +use metal::Buffer; +use std::sync::Arc; + +pub struct QMetalStorage { + dtype: GgmlDType, + device: MetalDevice, + buffer: Arc, +} + +impl QMetalStorage { + pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result { + let size = elem_count * dtype.type_size() / dtype.block_size(); + let buffer = device.allocate_zeros(size)?; + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn device(&self) -> &MetalDevice { + &self.device + } + + pub fn buffer(&self) -> &Buffer { + &self.buffer + } + + pub fn dequantize(&self, elem_count: usize) -> Result { + use crate::quantized::k_quants::GgmlType; + + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + self.device.wait_until_completed()?; + let mut out = vec![0.0; elem_count]; + let block_len = elem_count / self.dtype.block_size(); + match self.dtype { + GgmlDType::F32 => { + let vec: Vec = read_to_vec(&buffer, block_len); + f32::to_float(&vec, &mut out)?; + } + GgmlDType::F16 => { + let vec: Vec = read_to_vec(&buffer, block_len); + half::f16::to_float(&vec, &mut out)?; + } + GgmlDType::BF16 => { + let vec: Vec = read_to_vec(&buffer, block_len); + half::bf16::to_float(&vec, &mut out)?; + } + GgmlDType::Q4_0 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q4_1 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q5_0 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q5_1 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q8_0 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q8_1 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q2K => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; + } + GgmlDType::Q3K => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; + } + GgmlDType::Q4K => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; + } + GgmlDType::Q5K => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; + } + GgmlDType::Q6K => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ6K::to_float(&vec, &mut 
out)?; + } + GgmlDType::Q8K => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; + } + } + + let buffer = self.device.new_buffer_with_data(&out)?; + Ok(MetalStorage::new( + buffer, + self.device.clone(), + elem_count, + DType::F32, + )) + } + + pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> { + // Quantization only happens on CPU for now. + let src = src.to_cpu::()?; + let elem_count = src.len(); + let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + qcpu_storage.quantize(&src)?; + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + + pub fn quantize_onto(&mut self, src: &crate::CpuStorage) -> Result<()> { + // Quantization only happens on CPU for now. + let elem_count = src.as_slice::()?.len(); + let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?; + + if let QStorage::Cpu(storage) = &mut qcpu_storage { + storage.from_float(src.as_slice::()?)?; + } else { + unreachable!() + } + + let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?; + self.buffer = buffer; + Ok(()) + } + + pub fn storage_size_in_bytes(&self) -> usize { + self.buffer.length() as usize + } + + fn fwd_mv( + &self, + self_shape: &Shape, + storage: &MetalStorage, + layout: &crate::Layout, + ) -> Result<(MetalStorage, Shape)> { + use crate::MetalError; + + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k. + if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let (n, k) = self_shape.dims2()?; + let mut dst_shape = src_shape.dims().to_vec(); + + // We always use a single batch dimension and stack all the tensors in the batch on the + // second dimension as the implementation in mcandle-metal-kernels doesn't handle batch + // properly. + let m = match dst_shape.len() { + 3 => dst_shape[0] * dst_shape[1], + 2 => dst_shape[0], + n => crate::bail!("Invalid rank {n} for quantized matmul metal"), + }; + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + let device = storage.device().clone(); + let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; + let command_buffer = device.command_buffer()?; + // In some cases it would be better to use the mm variant, though it has its drawbacks + // around memory alignemnt. + for batch_id in 0..m { + mcandle_metal_kernels::call_quantized_matmul_mv_t( + device.device(), + &command_buffer, + device.kernels(), + self.dtype.into(), + (1, 1, n, k), + storage.buffer(), + (layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(), + &self.buffer, + batch_id * n * DType::F32.size_in_bytes(), + &dst, + ) + .map_err(MetalError::from)?; + } + let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32); + Ok((dst_storage, dst_shape)) + } + + pub fn fwd( + &self, + self_shape: &Shape, + storage: &MetalStorage, + layout: &crate::Layout, + ) -> Result<(MetalStorage, Shape)> { + use crate::MetalError; + + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k. 
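+        // This entry point covers the general matrix-matrix case; when the input
+        // collapses to a single row per matrix (see the dim check below) it falls
+        // back to fwd_mv, which dispatches the matrix-vector Metal kernel one
+        // batch element at a time.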
+ if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let n = self_shape.dim(D::Minus2)?; + let k = self_shape.dim(D::Minus1)?; + let mut dst_shape = src_shape.dims().to_vec(); + + if src_shape.rank() < self_shape.rank() { + crate::bail!( + "input rank ({}) must be >= weight rank ({})", + src_shape.rank(), + self_shape.rank() + ) + } + + if src_shape.dim(D::Minus2)? == 1 { + return self.fwd_mv(self_shape, storage, layout); + } + + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + let device = storage.device().clone(); + let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?; + let command_buffer = device.command_buffer()?; + + assert_eq!(storage.dtype(), DType::F32); + + if self_shape.rank() > 4 { + crate::bail!("weight rank ({}) must be <= 4", self_shape.rank()) + } + let src0_l = crate::Layout::contiguous( + [vec![1; 4 - self_shape.rank()], self_shape.dims().to_vec()].concat(), + ); + let src0_stride = src0_l + .stride() + .iter() + .map(|x| { + (*x as f32 * (self.dtype.type_size() as f32 / self.dtype.block_size() as f32)) + as usize + }) + .collect::>(); + + if src_shape.rank() > 4 { + crate::bail!("weight rank ({}) must be <= 4", src_shape.rank()) + } + let src1_l = crate::Layout::contiguous( + [vec![1; 4 - src_shape.rank()], src_shape.dims().to_vec()].concat(), + ); + + mcandle_metal_kernels::call_quantized_matmul_mm_t( + device.device(), + &command_buffer, + device.kernels(), + self.dtype.into(), + src0_l.dims(), + &src0_stride, + &self.buffer, + src1_l.dims(), + &src1_l + .stride() + .iter() + .map(|x| x * DType::F32.size_in_bytes()) + .collect::>(), + storage.buffer(), + src1_l.start_offset() * storage.dtype().size_in_bytes(), + dst_shape.dims(), + 0, + &dst, + ) + .map_err(MetalError::from)?; + + let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32); + Ok((dst_storage, dst_shape)) + } + + pub fn data(&self) -> Result> { + let buffer = self.device.new_buffer_managed(self.buffer.length())?; + { + let command_buffer = self.device.command_buffer()?; + command_buffer.set_label("to_cpu"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("blit_to_cpu"); + blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length()); + blit.end_encoding(); + } + self.device.wait_until_completed()?; + Ok(read_to_vec::(&buffer, self.buffer.length() as usize)) + } +} + +pub fn load_quantized( + device: &MetalDevice, + data: &[T], +) -> Result { + let buffer = device.new_buffer_with_data(data)?; + let device = device.clone(); + Ok(QStorage::Metal(QMetalStorage { + dtype: T::DTYPE, + device, + buffer, + })) +} + +fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} + +impl From for mcandle_metal_kernels::GgmlDType { + fn from(value: GgmlDType) -> Self { + match value { + GgmlDType::Q4_0 => mcandle_metal_kernels::GgmlDType::Q4_0, + GgmlDType::Q4_1 => mcandle_metal_kernels::GgmlDType::Q4_1, + GgmlDType::Q5_0 => mcandle_metal_kernels::GgmlDType::Q5_0, + GgmlDType::Q5_1 => mcandle_metal_kernels::GgmlDType::Q5_1, + GgmlDType::Q8_0 => mcandle_metal_kernels::GgmlDType::Q8_0, + GgmlDType::Q8_1 => mcandle_metal_kernels::GgmlDType::Q8_1, + GgmlDType::Q2K => 
mcandle_metal_kernels::GgmlDType::Q2K, + GgmlDType::Q3K => mcandle_metal_kernels::GgmlDType::Q3K, + GgmlDType::Q4K => mcandle_metal_kernels::GgmlDType::Q4K, + GgmlDType::Q5K => mcandle_metal_kernels::GgmlDType::Q5K, + GgmlDType::Q6K => mcandle_metal_kernels::GgmlDType::Q6K, + GgmlDType::Q8K => mcandle_metal_kernels::GgmlDType::Q8K, + GgmlDType::F16 => mcandle_metal_kernels::GgmlDType::F16, + GgmlDType::F32 => mcandle_metal_kernels::GgmlDType::F32, + GgmlDType::BF16 => mcandle_metal_kernels::GgmlDType::F16, + } + } +} diff --git a/mcandle/mcandle-core/src/quantized/mod.rs b/mcandle/mcandle-core/src/quantized/mod.rs new file mode 100644 index 000000000..dadbaf8fb --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/mod.rs @@ -0,0 +1,594 @@ +use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor}; +use k_quants::*; +use std::borrow::Cow; + +#[cfg(target_feature = "avx")] +pub mod avx; +mod dummy_cuda; +mod dummy_metal; +pub mod ggml_file; +pub mod gguf_file; +pub mod k_quants; +#[cfg(feature = "metal")] +pub mod metal; +#[cfg(not(feature = "metal"))] +mod metal { + pub use super::dummy_metal::*; +} +#[cfg(feature = "cuda")] +pub mod cuda; +#[cfg(not(feature = "cuda"))] +mod cuda { + pub use super::dummy_cuda::*; +} + +#[cfg(target_feature = "neon")] +pub mod neon; +#[cfg(target_feature = "simd128")] +pub mod simd128; +pub mod utils; +use half::{bf16, f16}; + +pub use k_quants::GgmlType; + +pub struct QTensor { + storage: QStorage, + shape: Shape, +} + +impl Device { + fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result { + match self { + Device::Cpu => { + let storage = dtype.cpu_zeros(elem_count); + Ok(QStorage::Cpu(storage)) + } + Device::Metal(metal) => { + let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?; + Ok(QStorage::Metal(storage)) + } + Device::Cuda(cuda) => { + let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?; + Ok(QStorage::Cuda(storage)) + } + } + } +} + +pub enum QStorage { + Cpu(Box), + Metal(metal::QMetalStorage), + Cuda(cuda::QCudaStorage), +} + +impl QStorage { + fn block_size(&self) -> usize { + match self { + QStorage::Cpu(storage) => storage.block_size(), + QStorage::Metal(storage) => storage.dtype().block_size(), + QStorage::Cuda(storage) => storage.dtype().block_size(), + } + } + + fn dtype(&self) -> GgmlDType { + match self { + QStorage::Cpu(storage) => storage.dtype(), + QStorage::Metal(storage) => storage.dtype(), + QStorage::Cuda(storage) => storage.dtype(), + } + } + + fn device(&self) -> Device { + match self { + QStorage::Cpu(_storage) => Device::Cpu, + QStorage::Metal(storage) => Device::Metal(storage.device().clone()), + QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()), + } + } + + fn size_in_bytes(&self) -> usize { + match self { + QStorage::Cpu(storage) => storage.storage_size_in_bytes(), + QStorage::Metal(storage) => storage.storage_size_in_bytes(), + QStorage::Cuda(storage) => storage.storage_size_in_bytes(), + } + } + + fn quantize(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + storage.from_float(src.as_slice::()?)?; + } + (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, + (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?, + _ => crate::bail!("Invalid quantize storage locations do not match"), + } + Ok(()) + } + + fn quantize_onto(&mut self, src: &Storage) -> Result<()> { + match (self, src) { + (QStorage::Cpu(storage), Storage::Cpu(src)) => { + 
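+                // cpu -> cpu: the f32 source slice is quantized directly into the
+                // boxed block storage, no device transfer is involved.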
storage.from_float(src.as_slice::()?)?; + } + (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?, + _ => crate::bail!("Invalid quantize source storage locations: not on cpu"), + } + Ok(()) + } + + fn dequantize(&self, elem_count: usize) -> Result { + match self { + QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), + QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), + QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)), + } + } + + fn data(&self) -> Result> { + match self { + QStorage::Cpu(storage) => { + let data_ptr = storage.as_ptr(); + let size_in_bytes = storage.storage_size_in_bytes(); + let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; + Ok(Cow::from(data)) + } + QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)), + QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum GgmlDType { + F32, + F16, + BF16, + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, + Q8K, +} + +impl GgmlDType { + pub(crate) fn from_u32(u: u32) -> Result { + let dtype = match u { + 0 => Self::F32, + 1 => Self::F16, + 2 => Self::Q4_0, + 3 => Self::Q4_1, + 6 => Self::Q5_0, + 7 => Self::Q5_1, + 8 => Self::Q8_0, + 9 => Self::Q8_1, + 10 => Self::Q2K, + 11 => Self::Q3K, + 12 => Self::Q4K, + 13 => Self::Q5K, + 14 => Self::Q6K, + 15 => Self::Q8K, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + 30 => Self::BF16, + _ => crate::bail!("unknown dtype for tensor {u}"), + }; + Ok(dtype) + } + + pub(crate) fn to_u32(self) -> u32 { + match self { + Self::F32 => 0, + Self::F16 => 1, + Self::Q4_0 => 2, + Self::Q4_1 => 3, + Self::Q5_0 => 6, + Self::Q5_1 => 7, + Self::Q8_0 => 8, + Self::Q8_1 => 9, + Self::Q2K => 10, + Self::Q3K => 11, + Self::Q4K => 12, + Self::Q5K => 13, + Self::Q6K => 14, + Self::Q8K => 15, + // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 + Self::BF16 => 30, + } + } + + /// The block dtype + pub fn cpu_zeros(&self, elem_count: usize) -> Box { + match self { + Self::F32 => Box::new(vec![f32::zeros(); elem_count]), + Self::F16 => Box::new(vec![f16::zeros(); elem_count]), + Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]), + Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]), + Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]), + Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]), + Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]), + Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]), + Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]), + Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]), + Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]), + Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), + Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), + Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]), + } + } + /// The type size for blocks in 
bytes. + pub fn type_size(&self) -> usize { + use k_quants::*; + match self { + Self::F32 => 4, + Self::F16 | Self::BF16 => 2, + Self::Q4_0 => std::mem::size_of::(), + Self::Q4_1 => std::mem::size_of::(), + Self::Q5_0 => std::mem::size_of::(), + Self::Q5_1 => std::mem::size_of::(), + // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932 + Self::Q8_0 => std::mem::size_of::(), + Self::Q8_1 => std::mem::size_of::(), + Self::Q2K => std::mem::size_of::(), + Self::Q3K => std::mem::size_of::(), + Self::Q4K => std::mem::size_of::(), + Self::Q5K => std::mem::size_of::(), + Self::Q6K => std::mem::size_of::(), + Self::Q8K => std::mem::size_of::(), + } + } + + /// The block size, i.e. the number of elements stored in each block. + pub fn block_size(&self) -> usize { + match self { + Self::F32 => 1, + Self::F16 | Self::BF16 => 1, + Self::Q4_0 => k_quants::QK4_0, + Self::Q4_1 => k_quants::QK4_1, + Self::Q5_0 => k_quants::QK5_0, + Self::Q5_1 => k_quants::QK5_1, + Self::Q8_0 => k_quants::QK8_0, + Self::Q8_1 => k_quants::QK8_1, + Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, + } + } +} + +// A version of GgmlType without `vec_dot` so that it can be dyn boxed. +pub trait QuantizedType: Send + Sync { + fn dtype(&self) -> GgmlDType; + fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>; + fn dequantize(&self, elem_count: usize) -> Result; + fn storage_size_in_bytes(&self) -> usize; + fn as_ptr(&self) -> *const u8; + fn block_size(&self) -> usize; + #[allow(clippy::wrong_self_convention)] + fn from_float(&mut self, xs: &[f32]) -> Result<()>; + fn size(&self) -> usize; +} + +impl QuantizedType for Vec { + fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { + k_quants::matmul(mkn, lhs, self.as_slice(), dst) + } + + fn size(&self) -> usize { + self.len() * core::mem::size_of::() + } + + fn from_float(&mut self, xs: &[f32]) -> Result<()> { + T::from_float(xs, self) + } + + fn dtype(&self) -> GgmlDType { + T::DTYPE + } + + fn block_size(&self) -> usize { + T::BLCK_SIZE + } + + fn dequantize(&self, elem_count: usize) -> Result { + let mut ys = vec![0.0f32; elem_count]; + T::to_float(self.as_slice(), &mut ys)?; + Ok(CpuStorage::F32(ys)) + } + + fn storage_size_in_bytes(&self) -> usize { + self.len() * std::mem::size_of::() + } + + fn as_ptr(&self) -> *const u8 { + self.as_ptr() as *const u8 + } +} + +impl std::fmt::Debug for QTensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype()) + } +} + +fn check_shape(shape: &Shape, block_size: usize) -> Result<()> { + let dims = shape.dims(); + if dims.is_empty() { + crate::bail!("scalar tensor cannot be quantized {shape:?}") + } + if dims[dims.len() - 1] % block_size != 0 { + crate::bail!( + "quantized tensor must have their last dim divisible by block size {shape:?} {}", + block_size + ) + } + Ok(()) +} + +impl QTensor { + pub fn new>(storage: QStorage, shape: S) -> Result { + let shape = shape.into(); + check_shape(&shape, storage.block_size())?; + Ok(Self { storage, shape }) + } + + pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result { + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not 
divisible by block size {}", + block_size + ) + } + let mut storage = src.device().qzeros(elem_count, dtype)?; + storage.quantize(&src.storage())?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + + /// Quantize `src` (currently on the CPU) to a QTensor on `dev` + pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result { + if !src.device().is_cpu() { + crate::bail!( + "`quantize_onto` expects a `src` to be on the cpu, got {:?}.", + src.device() + ) + } + let shape = src.shape(); + let block_size = dtype.block_size(); + check_shape(shape, block_size)?; + let src = src.to_dtype(crate::DType::F32)?.flatten_all()?; + let elem_count = shape.elem_count(); + if elem_count % block_size != 0 { + crate::bail!( + "tensor size ({shape:?}) is not divisible by block size {}", + block_size + ) + } + // storage is on the `dev`, src is on `cpu` + let mut storage = dev.qzeros(elem_count, dtype)?; + storage.quantize_onto(&src.storage())?; + Ok(Self { + storage, + shape: shape.clone(), + }) + } + + pub fn dtype(&self) -> GgmlDType { + self.storage.dtype() + } + + pub fn device(&self) -> Device { + self.storage.device() + } + + pub fn rank(&self) -> usize { + self.shape.rank() + } + + pub fn shape(&self) -> &Shape { + &self.shape + } + + pub fn dequantize(&self, device: &Device) -> Result { + let storage = self.storage.dequantize(self.shape.elem_count())?; + let none = crate::op::BackpropOp::none(); + crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device) + } + + pub fn dequantize_f16(&self, device: &Device) -> Result { + // In the CUDA case, we have a specialized kernel as this can be useful for volta + // architectures. https://github.com/huggingface/mcandle/issues/2136 + match &self.storage { + QStorage::Cuda(s) => { + let s = s.dequantize_f16(self.shape.elem_count())?; + let none = crate::op::BackpropOp::none(); + crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false) + .to_device(device) + } + _ => { + let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?; + Ok(s) + } + } + } + + pub fn storage_size_in_bytes(&self) -> usize { + self.storage.size_in_bytes() + } + + pub fn data(&self) -> Result> { + self.storage.data() + } +} + +#[derive(Clone, Debug)] +pub enum QMatMul { + QTensor(std::sync::Arc), + Tensor(Tensor), + TensorF16(Tensor), +} + +thread_local! { + static DEQUANTIZE_ALL: bool = { + match std::env::var("CANDLE_DEQUANTIZE_ALL") { + Ok(s) => { + !s.is_empty() && s != "0" + }, + Err(_) => false, + } + } +} + +thread_local! 
{ + static DEQUANTIZE_ALL_F16: bool = { + match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") { + Ok(s) => { + !s.is_empty() && s != "0" + }, + Err(_) => false, + } + } +} + +impl QMatMul { + pub fn from_arc(qtensor: std::sync::Arc) -> Result { + let dequantize = match qtensor.dtype() { + GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true, + _ => DEQUANTIZE_ALL.with(|b| *b), + }; + let t = if dequantize { + let tensor = qtensor.dequantize(&qtensor.device())?; + Self::Tensor(tensor) + } else if DEQUANTIZE_ALL_F16.with(|b| *b) { + let tensor = qtensor.dequantize_f16(&qtensor.device())?; + Self::TensorF16(tensor) + } else { + Self::QTensor(qtensor) + }; + Ok(t) + } + + pub fn from_qtensor(qtensor: QTensor) -> Result { + Self::from_arc(std::sync::Arc::new(qtensor)) + } + + pub fn dequantize_f16(&self) -> Result { + match self { + Self::QTensor(t) => t.dequantize_f16(&t.device()), + Self::Tensor(t) => t.to_dtype(DType::F16), + Self::TensorF16(t) => Ok(t.clone()), + } + } + + pub fn forward_via_f16(&self, xs: &Tensor) -> Result { + let w = self.dequantize_f16()?; + let in_dtype = xs.dtype(); + let w = match *xs.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype) + } +} + +impl crate::CustomOp1 for QTensor { + fn name(&self) -> &'static str { + "qmatmul" + } + + fn cpu_fwd( + &self, + storage: &crate::CpuStorage, + layout: &crate::Layout, + ) -> Result<(crate::CpuStorage, Shape)> { + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + // self is transposed so n is first then k. + let (n, k) = self.shape.dims2()?; + if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let mut dst_shape = src_shape.dims().to_vec(); + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + #[allow(clippy::infallible_destructuring_match)] + let self_storage = match &self.storage { + QStorage::Cpu(storage) => storage, + QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"), + }; + let slice = storage.as_slice::()?; + let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; + let mut dst_storage = vec![0f32; dst_shape.elem_count()]; + self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?; + Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) + } + + fn metal_fwd( + &self, + storage: &crate::MetalStorage, + layout: &crate::Layout, + ) -> Result<(crate::MetalStorage, Shape)> { + let self_storage = match &self.storage { + QStorage::Metal(metal) => metal, + _ => unreachable!("Cannot call metal matmul on non metal QTensor"), + }; + self_storage.fwd(&self.shape, storage, layout) + } + + fn cuda_fwd( + &self, + storage: &crate::CudaStorage, + layout: &crate::Layout, + ) -> Result<(crate::CudaStorage, Shape)> { + let self_storage = match &self.storage { + QStorage::Cuda(cuda) => cuda, + _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"), + }; + self_storage.fwd(&self.shape, storage, layout) + } +} + +impl crate::Module for QMatMul { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()), + Self::Tensor(w) => { + let w = match *xs.dims() { + [b1, b2, _, _] => 
w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + xs.matmul(&w) + } + Self::TensorF16(w) => { + let in_dtype = xs.dtype(); + let w = match *xs.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype) + } + } + } +} diff --git a/mcandle/mcandle-core/src/quantized/neon.rs b/mcandle/mcandle-core/src/quantized/neon.rs new file mode 100644 index 000000000..c4d5d6f41 --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/neon.rs @@ -0,0 +1,613 @@ +use super::k_quants::{ + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +}; +use crate::Result; +use byteorder::{ByteOrder, LittleEndian}; + +#[allow(unused_imports)] +#[cfg(target_arch = "arm")] +use core::arch::arm::*; + +#[allow(unused_imports)] +#[cfg(target_arch = "aarch64")] +use core::arch::aarch64::*; + +#[inline(always)] +unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { + // TODO: dotprod + let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) +} + +#[inline(always)] +pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + let nb = n / qk; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + + unsafe { + let mut sumv0 = vdupq_n_f32(0.0f32); + for i in 0..nb { + let x0 = &xs[i]; + let y0 = &ys[i]; + + let m4b = vdupq_n_u8(0x0F); + let s8b = vdupq_n_s8(0x8); + + let v0_0 = vld1q_u8(x0.qs.as_ptr()); + + // 4-bit -> 8-bit + let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); + let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + + // sub 8 + let v0_0ls = vsubq_s8(v0_0l, s8b); + let v0_0hs = vsubq_s8(v0_0h, s8b); + + // load y + let v1_0l = vld1q_s8(y0.qs.as_ptr()); + let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); + + let pl0 = vdotq_s32(v0_0ls, v1_0l); + let ph0 = vdotq_s32(v0_0hs, v1_0h); + sumv0 = vmlaq_n_f32( + sumv0, + vcvtq_f32_s32(vaddq_s32(pl0, ph0)), + x0.d.to_f32() * y0.d.to_f32(), + ); + } + Ok(vaddvq_f32(sumv0)) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + unsafe { + let mut sumv0 = vdupq_n_f32(0.0f32); + for i in 0..nb { + let x0 = &xs[i]; + let y0 = &ys[i]; + + let x0_0 = vld1q_s8(x0.qs.as_ptr()); + let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); + + // load y + let y0_0 = vld1q_s8(y0.qs.as_ptr()); + let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); + + let p0 = vdotq_s32(x0_0, y0_0); + let p1 = vdotq_s32(x0_1, y0_1); + + sumv0 = vmlaq_n_f32( + sumv0, + vcvtq_f32_s32(vaddq_s32(p0, p1)), + x0.d.to_f32() * y0.d.to_f32(), + ); + } + Ok(vaddvq_f32(sumv0)) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") + } + + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + unsafe { + let mut sum_i = vdupq_n_s32(0); + let scale = xs.d * ys.d; + let xs = xs.qs.as_ptr(); + let ys = ys.qs.as_ptr(); + for i in (0..QK_K).step_by(16) { + let xs = vld1q_s8(xs.add(i)); + let ys = vld1q_s8(ys.add(i)); + let xy = vdotq_s32(xs, ys); + sum_i = 
vaddq_s32(sum_i, xy) + } + sumf += vaddvq_s32(sum_i) as f32 * scale + } + } + Ok(sumf) +} + +#[inline(always)] +pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}") + } + let mut sum = 0f32; + unsafe { + let m4b = vdupq_n_u8(0xF); + + let mone = vdupq_n_u8(3); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d_all = x.d.to_f32(); + + let mut q6 = x.ql.as_ptr(); + let mut qh = x.qh.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut scale = x.scales.as_ptr(); + + let q8sums = vld1q_s16_x2(y.bsums.as_ptr()); + let scales = vld1q_s8(scale); + let q6scales = int16x8x2_t( + vmovl_s8(vget_low_s8(scales)), + vmovl_s8(vget_high_s8(scales)), + ); + + let prod = vaddq_s32( + vaddq_s32( + vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)), + vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)), + ), + vaddq_s32( + vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)), + vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)), + ), + ); + let isum_mins = vaddvq_s32(prod); + + let mut isum = 0i32; + + for _j in 0..QK_K / 128 { + let qhbits = vld1q_u8_x2(qh); + qh = qh.add(32); + let q6bits = vld1q_u8_x4(q6); + q6 = q6.add(64); + let q8bytes = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4); + let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4); + let shifted = vshrq_n_u8(qhbits.0, 2); + let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits.1, 2); + let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0)); + let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1)); + let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); + let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); + + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); + let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; + scale = scale.add(2); + + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); + let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; + scale = scale.add(2); + + let q8bytes = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let shifted = vshrq_n_u8(qhbits.0, 4); + let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits.1, 4); + let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits.0, 6); + let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + let shifted = vshrq_n_u8(qhbits.1, 6); + let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0)); + let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1)); + let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); + let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); + + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); + let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; + scale = scale.add(2); + + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); + let (scale0, 
scale1) = (*scale as i32, *scale.add(1) as i32); + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; + scale = scale.add(2); + } + sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); + } + } + Ok(sum) +} + +#[inline(always)] +pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(1); + let mtwo = vdupq_n_u8(2); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let q8sums = vpaddq_s16( + vld1q_s16(y.bsums.as_ptr()), + vld1q_s16(y.bsums.as_ptr().add(8)), + ); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8)); + let mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + let prod = vaddq_s32( + vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)), + ); + let sumi_mins = vaddvq_s32(prod); + + let mut scales = utmp.as_ptr() as *const u8; + + let mut q5 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut qhbits = vld1q_u8_x2(x.qh.as_ptr()); + + let mut sumi = 0i32; + + for _j in 0..QK_K / 64 { + let q5bits = vld1q_u8_x2(q5); + q5 = q5.add(32); + let q8bytes = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4); + let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4); + let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3); + let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3); + qhbits.0 = vshrq_n_u8(qhbits.0, 2); + qhbits.1 = vshrq_n_u8(qhbits.1, 2); + + let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0)); + let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1)); + let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); + let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); + + let p0 = vdotq_s32(q5bytes_0, q8bytes.0); + let p1 = vdotq_s32(q5bytes_1, q8bytes.1); + sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; + scales = scales.add(1); + + let p2 = vdotq_s32(q5bytes_2, q8bytes.2); + let p3 = vdotq_s32(q5bytes_3, q8bytes.3); + sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; + scales = scales.add(1); + } + sumf += d * sumi as f32 - dmin * sumi_mins as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + let mut scales = [0u8; 16]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4b = vdupq_n_u8(0xF); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let q8sums = vpaddq_s16( + vld1q_s16(y.bsums.as_ptr()), + vld1q_s16(y.bsums.as_ptr().add(8)), + ); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + let mins8 = vld1_u32( + [ + utmp[1] & 
KMASK1, + ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4), + ] + .as_ptr(), + ); + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[0] &= KMASK1; + + let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + let prod = vaddq_s32( + vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)), + ); + sumf -= dmin * vaddvq_s32(prod) as f32; + + LittleEndian::write_u32_into(&utmp, &mut scales); + + let mut q4 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut sumi1 = 0i32; + let mut sumi2 = 0i32; + + for j in 0..QK_K / 64 { + let q4bits = vld1q_u8_x2(q4); + q4 = q4.add(32); + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + let q4bytes = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), + ); + let p0 = vdotq_s32(q4bytes.0, q8bytes.0); + let p1 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32; + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + let q4bytes = int8x16x2_t( + vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), + ); + let p2 = vdotq_s32(q4bytes.0, q8bytes.0); + let p3 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; + } + sumf += d * (sumi1 + sumi2) as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + let mut aux = [0u32; 3]; + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + unsafe { + let m3b = vdupq_n_u8(0x3); + let m0 = vdupq_n_u8(1); + let m1 = vshlq_n_u8(m0, 1); + let m2 = vshlq_n_u8(m0, 2); + let m3 = vshlq_n_u8(m0, 3); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let mut q3 = x.qs.as_ptr(); + let qh = x.hmask.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut qhbits = vld1q_u8_x2(qh); + + let mut isum = 0i32; + + // Set up scales + LittleEndian::read_u32_into(&x.scales, &mut aux); + + utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4); + utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4); + utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4); + utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4); + + let mut scale = utmp.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale.add(j) -= 32i8 + } + + for j in 0..QK_K / 128 { + let q3bits = vld1q_u8_x2(q3); + q3 = q3.add(32); + let q8bytes_1 = vld1q_s8_x4(q8); + q8 = q8.add(64); + let q8bytes_2 = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2); + let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2); + let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1); + let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1); + + let q3bytes_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)), + vreinterpretq_s8_u8(q3h_0), + ); + let q3bytes_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)), + vreinterpretq_s8_u8(q3h_1), + ); + let q3bytes_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_2), + ); + let q3bytes_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_3), + ); + + let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); + let p1 = vdotq_s32(q3bytes_1, 
q8bytes_1.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; + scale = scale.add(4); + + let q3h_0 = vbicq_u8(m2, qhbits.0); + let q3h_1 = vbicq_u8(m2, qhbits.1); + let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1); + let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1); + + let q3bytes_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)), + vreinterpretq_s8_u8(q3h_0), + ); + let q3bytes_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)), + vreinterpretq_s8_u8(q3h_1), + ); + let q3bytes_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)), + vreinterpretq_s8_u8(q3h_2), + ); + let q3bytes_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)), + vreinterpretq_s8_u8(q3h_3), + ); + + let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; + scale = scale.add(4); + + if j == 0 { + qhbits.0 = vshrq_n_u8(qhbits.0, 4); + qhbits.1 = vshrq_n_u8(qhbits.1, 4); + } + } + sumf += d * isum as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut aux = [0u8; 16]; + + unsafe { + let m3 = vdupq_n_u8(0x3); + let m4 = vdupq_n_u8(0xF); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + let mut q2 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + let sc = x.scales.as_ptr(); + + let mins_and_scales = vld1q_u8(sc); + let scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux.as_mut_ptr(), scales); + + let mins = vshrq_n_u8(mins_and_scales, 4); + let q8sums = vld1q_s16_x2(y.bsums.as_ptr()); + let mins16 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))), + ); + let s0 = vaddq_s32( + vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)), + vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)), + ); + let s1 = vaddq_s32( + vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)), + vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)), + ); + sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32; + + let mut isum = 0i32; + let mut is = 0usize; + + // TODO: dotprod + for _j in 0..QK_K / 128 { + let q2bits = vld1q_u8_x2(q2); + q2 = q2.add(32); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + let mut q2bytes = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)), + ); + isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3)); + isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3)); 
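+                // (continuing the shift-by-4 extraction) bits 4..5 of each byte are
+                // the third of the four 2-bit sub-groups packed per byte; the
+                // shift-by-6 pass below picks up the last one.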
+ q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3)); + isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3)); + isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes); + + is += 8; + } + sumf += d * isum as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +unsafe fn multiply_accum_with_scale( + aux: &[u8; 16], + is: usize, + index: usize, + q2bytes: int8x16x2_t, + q8bytes: int8x16x2_t, +) -> i32 { + let p1 = vdotq_s32(q2bytes.0, q8bytes.0); + let p2 = vdotq_s32(q2bytes.1, q8bytes.1); + vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 +} diff --git a/mcandle/mcandle-core/src/quantized/simd128.rs b/mcandle/mcandle-core/src/quantized/simd128.rs new file mode 100644 index 000000000..1c8c0f206 --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/simd128.rs @@ -0,0 +1,419 @@ +use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; +use crate::Result; +use byteorder::{ByteOrder, LittleEndian}; +use half::f16; + +use core::arch::wasm32::*; + +#[inline(always)] +pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (x, y) in xs.iter().zip(ys.iter()) { + let x1234 = v128_load(x.qs.as_ptr() as *const v128); + let x12 = v128_and(x1234, u8x16_splat(0x0F)); + let x12 = i8x16_sub(x12, i8x16_splat(8)); + let x34 = u8x16_shr(x1234, 4); + let x34 = i8x16_sub(x34, i8x16_splat(8)); + + let x1 = i16x8_extend_low_i8x16(x12); + let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr()); + let sum_xy = i32x4_dot_i16x8(x1, y1); + + let x2 = i16x8_extend_high_i8x16(x12); + let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2)); + + let x3 = i16x8_extend_low_i8x16(x34); + let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3)); + + let x4 = i16x8_extend_high_i8x16(x34); + let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4)); + + let sum_xy = f32x4_convert_i32x4(sum_xy); + + // f32x4_relaxed_madd is nightly only. 
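+            // Without relaxed SIMD the scaling is a separate multiply-add: the i32
+            // dot products are converted to f32 lanes and scaled by the product of
+            // the two per-block scales before being accumulated into `acc`.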
+ let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d)); + let scaled = f32x4_mul(sum_xy, d); + acc = f32x4_add(acc, scaled) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (x, y) in xs.iter().zip(ys.iter()) { + let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr()); + let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr()); + let sum_xy = i32x4_dot_i16x8(x1, y1); + + let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8)); + let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2)); + + let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16)); + let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3)); + + let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24)); + let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4)); + + let sum_xy = f32x4_convert_i32x4(sum_xy); + + // f32x4_relaxed_madd is nightly only. + let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d)); + let scaled = f32x4_mul(sum_xy, d); + acc = f32x4_add(acc, scaled) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + unsafe { + let mut sumf = f32x4_splat(0f32); + for (x, y) in xs.iter().zip(ys.iter()) { + let mut q2: &[_] = &x.qs; + let mut q8: &[_] = &y.qs; + let sc = &x.scales; + + let mut summs = i32x4_splat(0); + for i in (0..(QK_K / 16)).step_by(4) { + let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i)); + let scales = i32x4_shr( + i32x4( + sc[i] as i32, + sc[i + 1] as i32, + sc[i + 2] as i32, + sc[i + 3] as i32, + ), + 4, + ); + summs = i32x4_add(summs, i32x4_mul(bsums, scales)) + } + let summs = f32x4_convert_i32x4(summs); + + let dall = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let mut isum = i32x4_splat(0); + let mut is = 0; + for _ in 0..(QK_K / 128) { + let mut shift = 0; + for _ in 0..4 { + let d = (sc[is] & 0xF) as i32; + is += 1; + let mut isuml = i16x8_splat(0); + for l in (0..16).step_by(8) { + let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l)); + let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l)); + let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3)); + isuml = i16x8_add(isuml, i16x8_mul(q2, q8)) + } + let dd = i32x4_splat(d); + isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd)); + isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd)); + let d = (sc[is] & 0xF) as i32; + is += 1; + let mut isuml = i16x8_splat(0); + for l in (16..32).step_by(8) { + let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l)); + let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l)); + let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3)); + isuml = i16x8_add(isuml, i16x8_mul(q2, q8)) + } + let dd = i32x4_splat(d); + isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd)); + isum = i32x4_add(isum, 
i32x4_mul(i32x4_extend_high_i16x8(isuml), dd)); + shift += 2; + // adjust the indexing + q8 = &q8[32..]; + } + // adjust the indexing + q2 = &q2[32..]; + } + let isum = f32x4_convert_i32x4(isum); + sumf = f32x4_add( + sumf, + f32x4_sub( + f32x4_mul(isum, f32x4_splat(dall)), + f32x4_mul(summs, f32x4_splat(dmin)), + ), + ); + } + let sumf = f32x4_extract_lane::<0>(sumf) + + f32x4_extract_lane::<1>(sumf) + + f32x4_extract_lane::<2>(sumf) + + f32x4_extract_lane::<3>(sumf); + Ok(sumf) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + let mut utmp: [u32; 4] = [0; 4]; + let mut scales: [u8; 8] = [0; 8]; + let mut mins: [u8; 8] = [0; 8]; + + let mut aux8: [u8; QK_K] = [0; QK_K]; + let mut sums = f32x4_splat(0f32); + unsafe { + for (y, x) in ys.iter().zip(xs.iter()) { + let q4 = &x.qs; + let q8 = &y.qs; + + for j in 0..QK_K / 64 { + let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128); + let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128); + v128_store( + aux8.as_mut_ptr().add(64 * j) as *mut v128, + v128_and(q4_1, u8x16_splat(0x0F)), + ); + v128_store( + aux8.as_mut_ptr().add(64 * j + 16) as *mut v128, + v128_and(q4_2, u8x16_splat(0x0F)), + ); + v128_store( + aux8.as_mut_ptr().add(64 * j + 32) as *mut v128, + u8x16_shr(q4_1, 4), + ); + v128_store( + aux8.as_mut_ptr().add(64 * j + 48) as *mut v128, + u8x16_shr(q4_2, 4), + ); + } + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + //extract scales and mins + LittleEndian::write_u32_into(&utmp[0..2], &mut scales); + LittleEndian::write_u32_into(&utmp[2..4], &mut mins); + + let mut sumi = i32x4_splat(0); + for j in (0..QK_K / 16).step_by(4) { + let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j)); + let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32); + let mins = i32x4(m1, m1, m2, m2); + sumi = i32x4_add(sumi, i32x4_mul(bsums, mins)); + } + + let mut aux32 = i32x4_splat(0i32); + for (scale_i, scale) in scales.iter().enumerate() { + let scale = i32x4_splat(*scale as i32); + for j in 0..4 { + let i = 32 * scale_i + 8 * j; + let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i)); + let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i)); + let aux16 = i16x8_mul(q8, aux8); + aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16))); + aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16))); + } + } + let aux32 = f32x4_convert_i32x4(aux32); + let d = f32x4_splat(x.d.to_f32() * y.d); + sums = f32x4_add(sums, f32x4_mul(aux32, d)); + let dmin = x.dmin.to_f32() * y.d; + let dmin = f32x4_splat(dmin); + let sumi = f32x4_convert_i32x4(sumi); + sums = f32x4_sub(sums, f32x4_mul(sumi, dmin)); + } + let sums = f32x4_extract_lane::<0>(sums) + + f32x4_extract_lane::<1>(sums) + + f32x4_extract_lane::<2>(sums) + + f32x4_extract_lane::<3>(sums); + Ok(sums) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}") + } + + let mut aux8 = [0i8; QK_K]; + unsafe { + let mut sums = 
f32x4_splat(0f32); + + for (x, y) in xs.iter().zip(ys.iter()) { + let q4 = &x.ql; + let qh = &x.qh; + let q8 = &y.qs; + let mut aux32 = f32x4_splat(0f32); + + for j in (0..QK_K).step_by(128) { + let aux8 = aux8.as_mut_ptr().add(j); + let q4 = &q4.as_ptr().add(j / 2); + let qh = &qh.as_ptr().add(j / 4); + for l in (0..32).step_by(16) { + // aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8; + let a8 = v128_or( + v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)), + u8x16_shl( + v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)), + 4, + ), + ); + let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32)); + let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32)); + v128_store( + aux8.add(l) as *mut v128, + i8x16_narrow_i16x8(a8_low, a8_high), + ); + + // aux8[l + 32] = + // (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8; + let a8 = v128_or( + v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)), + u8x16_shl( + v128_and( + u8x16_shr(v128_load(qh.add(l) as *const v128), 2), + u8x16_splat(3), + ), + 4, + ), + ); + let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32)); + let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32)); + v128_store( + aux8.add(l + 32) as *mut v128, + i8x16_narrow_i16x8(a8_low, a8_high), + ); + + // aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8; + let a8 = v128_or( + u8x16_shr(v128_load(q4.add(l) as *const v128), 4), + u8x16_shl( + v128_and( + u8x16_shr(v128_load(qh.add(l) as *const v128), 4), + u8x16_splat(3), + ), + 4, + ), + ); + let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32)); + let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32)); + v128_store( + aux8.add(l + 64) as *mut v128, + i8x16_narrow_i16x8(a8_low, a8_high), + ); + + // aux8[l + 96] = + // (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8; + let a8 = v128_or( + u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4), + u8x16_shl( + v128_and( + u8x16_shr(v128_load(qh.add(l) as *const v128), 6), + u8x16_splat(3), + ), + 4, + ), + ); + let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32)); + let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32)); + v128_store( + aux8.add(l + 96) as *mut v128, + i8x16_narrow_i16x8(a8_low, a8_high), + ); + } + } + + for (j, &scale) in x.scales.iter().enumerate() { + let scale = f32x4_splat(scale as f32); + for offset in [0, 8] { + let aux16 = i16x8_mul( + i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)), + i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)), + ); + aux32 = f32x4_add( + aux32, + f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale), + ); + aux32 = f32x4_add( + aux32, + f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale), + ); + } + } + + let d = f32x4_splat(x.d.to_f32() * y.d); + sums = f32x4_add(sums, f32x4_mul(aux32, d)); + } + let sums = f32x4_extract_lane::<0>(sums) + + f32x4_extract_lane::<1>(sums) + + f32x4_extract_lane::<2>(sums) + + f32x4_extract_lane::<3>(sums); + Ok(sums) + } +} + +#[inline(always)] +pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") + } + + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (xs, ys) in xs.iter().zip(ys.iter()) { + let x_qs = xs.qs.as_ptr(); + let y_qs = ys.qs.as_ptr(); + let mut sumi = 
i32x4_splat(0); + for j in (0..QK_K).step_by(8) { + let xs = i16x8_load_extend_i8x8(x_qs.add(j)); + let ys = i16x8_load_extend_i8x8(y_qs.add(j)); + let sum_xy = i32x4_dot_i16x8(xs, ys); + sumi = i32x4_add(sumi, sum_xy) + } + let d = f32x4_splat(xs.d * ys.d); + acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d)) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} diff --git a/mcandle/mcandle-core/src/quantized/utils.rs b/mcandle/mcandle-core/src/quantized/utils.rs new file mode 100644 index 000000000..fa6eff51d --- /dev/null +++ b/mcandle/mcandle-core/src/quantized/utils.rs @@ -0,0 +1,326 @@ +use crate::Result; + +pub(super) fn nearest_int(v: f32) -> i32 { + v.round() as i32 +} + +/// Validates that the input and output are the right size and returns an iterator which maps each +/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed +/// to be `T::BLCK_SIZE` long. +pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>( + xs: &'b [f32], + ys: &'a mut [T], +) -> Result> { + let block_size = T::BLCK_SIZE; + let dtype = T::DTYPE; + + let expected_blocks = xs.len() / block_size; + let actual_blocks = ys.len(); + + // Validate that the input is the right size + if expected_blocks != actual_blocks { + crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!") + } + + Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect()) +} + +/// Validates that the input and output are the right size and returns an iterator which maps each +/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed +/// to be `T::BLCK_SIZE` long. +pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>( + xs: &'a [T], + ys: &'b mut [f32], +) -> Result> { + let block_size = T::BLCK_SIZE; + let dtype = T::DTYPE; + + let actual_output_len = ys.len(); + let expected_output_len = xs.len() * block_size; + // Validate that the output is the right size + if expected_output_len != actual_output_len { + crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!") + } + + // Zip the blocks and outputs together + Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect()) +} + +pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) { + if j < 4 { + let d = q[j] & 63; + let m = q[j + 4] & 63; + (d, m) + } else { + let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4); + (d, m) + } +} + +pub(super) unsafe fn make_qx_quants( + n: usize, + nmax: i32, + x: *const f32, + ls: *mut i8, + rmse_type: i32, +) -> f32 { + let mut max = 0f32; + let mut amax = 0f32; + for i in 0..n { + let x = *x.add(i); + let ax = x.abs(); + if ax > amax { + amax = ax; + max = x; + } + } + if amax == 0. 
{ + // all zero + for i in 0..n { + *ls.add(i) = 0; + } + return 0.; + } + let mut iscale = -(nmax as f32) / max; + if rmse_type == 0 { + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; + } + return 1.0 / iscale; + } + let weight_type = rmse_type % 2; + let mut sumlx = 0f32; + let mut suml2 = 0f32; + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + let l = l.clamp(-nmax, nmax - 1); + *ls.add(i) = (l + nmax) as i8; + let w = if weight_type == 1 { x * x } else { 1.0 }; + let l = l as f32; + sumlx += w * x * l; + suml2 += w * l * l; + } + let mut scale = sumlx / suml2; + let mut best = scale * sumlx; + for _itry in 0..3 { + let iscale = 1.0 / scale; + let mut slx = 0f32; + let mut sl2 = 0f32; + let mut changed = false; + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + let l = l.clamp(-nmax, nmax - 1); + if l + nmax != *ls.add(i) as i32 { + changed = true; + } + let w = if weight_type == 1 { x * x } else { 1f32 }; + let l = l as f32; + slx += w * x * l; + sl2 += w * l * l; + } + if !changed || sl2 == 0.0 || slx * slx <= best * sl2 { + break; + } + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; + } + sumlx = slx; + suml2 = sl2; + scale = sumlx / suml2; + best = scale * sumlx; + } + for _itry in 0..5 { + let mut n_changed = 0; + for i in 0..n { + let x = *x.add(i); + let w = if weight_type == 1 { x * x } else { 1. }; + let l = *ls.add(i) as i32 - nmax; + let mut slx = sumlx - w * x * l as f32; + if slx > 0. { + let mut sl2 = suml2 - w * l as f32 * l as f32; + let new_l = nearest_int(x * sl2 / slx); + let new_l = new_l.clamp(-nmax, nmax - 1); + if new_l != l { + slx += w * x * new_l as f32; + sl2 += w * new_l as f32 * new_l as f32; + if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 { + *ls.add(i) = (nmax + new_l) as i8; + sumlx = slx; + suml2 = sl2; + scale = sumlx / suml2; + best = scale * sumlx; + n_changed += 1; + } + } + } + } + if n_changed == 0 { + break; + } + } + if rmse_type < 3 { + return scale; + } + for is in -4..4 { + if is == 0 { + continue; + } + iscale = -(nmax as f32 + 0.1f32 * is as f32) / max; + let mut sumlx = 0.; + let mut suml2 = 0.; + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + let l = l.clamp(-nmax, nmax - 1); + let w = if weight_type == 1 { x * x } else { 1. }; + let l = l as f32; + sumlx += w * x * l; + suml2 += w * l * l; + } + if suml2 > 0. 
&& sumlx * sumlx > best * suml2 { + for i in 0..n { + let x = *x.add(i); + let l = nearest_int(iscale * x); + *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8; + } + scale = sumlx / suml2; + best = scale * sumlx; + } + } + scale +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224 +pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) { + let n = x.len(); + let mut l = vec![0; n]; + // Get min/max + let min = *x + .iter() + .take(n) + .min_by(|a, b| a.total_cmp(b)) + .unwrap_or(&x[0]); + let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]); + + // If min == max, all values are the same => nothing to do here + if max == min { + return (0.0, 0.0); + } + + // Ensure min <= 0.0 + let mut min = min.min(0.); + + // Compute scale and inverse scale + let mut iscale = nmax as f32 / (max - min); + let mut scale = 1.0 / iscale; + + for _ in 0..ntry { + let mut sumlx = 0.0; + let mut suml2 = 0; + let mut did_change = false; + + for (i, value) in x.iter().enumerate().take(n) { + let li = nearest_int(iscale * (value - min)).clamp(0, nmax); + let clamped_li = li as u8; + if clamped_li != l[i] { + l[i] = clamped_li; + did_change = true; + } + sumlx += (value - min) * li as f32; + suml2 += li * li; + } + scale = sumlx / suml2 as f32; + + let sum: f32 = x + .iter() + .take(n) + .zip(l.iter().take(n)) + .map(|(xi, &li)| xi - scale * li as f32) + .sum(); + + min = sum / n as f32; + if min > 0.0 { + min = 0.0; + } + iscale = 1.0 / scale; + if !did_change { + break; + } + } + (scale, -min) +} + +// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165 +pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 { + let n = x.len(); + let mut l = vec![0i8; n]; + + let mut max = 0.0; + let mut amax = 0.0; + for &xi in x.iter().take(n) { + let ax = xi.abs(); + if ax > amax { + amax = ax; + max = xi; + } + } + + if amax == 0.0 { + return 0.0; + } + + let iscale = -(nmax as f32) / max; + if do_rmse { + let mut sumlx = 0.0; + let mut suml2 = 0.0; + for i in 0..n { + let li = (iscale * x[i]).round() as i32; + let li = li.clamp(-nmax, nmax - 1); + l[i] = li as i8; + let w = x[i] * x[i]; + sumlx += w * x[i] * li as f32; + suml2 += w * (li * li) as f32; + } + for _ in 0..5 { + let mut n_changed = 0; + for i in 0..n { + let w = x[i] * x[i]; + let mut slx = sumlx - w * x[i] * l[i] as f32; + if slx > 0.0 { + let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32; + let mut new_l = (x[i] * sl2 / slx).round() as i32; + new_l = new_l.clamp(-nmax, nmax - 1); + if new_l != l[i] as i32 { + slx += w * x[i] * new_l as f32; + sl2 += w * (new_l * new_l) as f32; + if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 { + l[i] = new_l as i8; + sumlx = slx; + suml2 = sl2; + n_changed += 1; + } + } + } + } + if n_changed == 0 { + break; + } + } + for li in l.iter_mut() { + *li += nmax as i8; + } + return sumlx / suml2; + } + for i in 0..n { + let li = (iscale * x[i]).round() as i32; + l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8; + } + 1.0 / iscale +} diff --git a/mcandle/mcandle-core/src/safetensors.rs b/mcandle/mcandle-core/src/safetensors.rs new file mode 100644 index 000000000..52df16631 --- /dev/null +++ b/mcandle/mcandle-core/src/safetensors.rs @@ -0,0 +1,474 @@ +use crate::{DType, Device, Error, Result, Tensor, WithDType}; +use float8::F8E4M3; +use safetensors::tensor as st; +use safetensors::tensor::SafeTensors; +use std::borrow::Cow; +use 
std::collections::HashMap; +use std::path::Path; + +impl From for st::Dtype { + fn from(value: DType) -> Self { + match value { + DType::U8 => st::Dtype::U8, + DType::U32 => st::Dtype::U32, + DType::I64 => st::Dtype::I64, + DType::I16 => st::Dtype::I16, + DType::I32 => st::Dtype::I32, + DType::BF16 => st::Dtype::BF16, + DType::F16 => st::Dtype::F16, + DType::F32 => st::Dtype::F32, + DType::F64 => st::Dtype::F64, + DType::F8E4M3 => st::Dtype::F8_E4M3, + } + } +} + +impl TryFrom for DType { + type Error = Error; + fn try_from(value: st::Dtype) -> Result { + match value { + st::Dtype::U8 => Ok(DType::U8), + st::Dtype::U32 => Ok(DType::U32), + st::Dtype::I64 => Ok(DType::I64), + st::Dtype::BF16 => Ok(DType::BF16), + st::Dtype::F16 => Ok(DType::F16), + st::Dtype::F32 => Ok(DType::F32), + st::Dtype::F64 => Ok(DType::F64), + st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), + dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), + } + } +} + +impl st::View for Tensor { + fn dtype(&self) -> st::Dtype { + self.dtype().into() + } + fn shape(&self) -> &[usize] { + self.shape().dims() + } + + fn data(&self) -> Cow<[u8]> { + // This copies data from GPU to CPU. + // TODO: Avoid the unwrap here. + Cow::Owned(convert_back(self).unwrap()) + } + + fn data_len(&self) -> usize { + let n: usize = self.shape().elem_count(); + let bytes_per_element = self.dtype().size_in_bytes(); + n * bytes_per_element + } +} + +impl st::View for &Tensor { + fn dtype(&self) -> st::Dtype { + (*self).dtype().into() + } + fn shape(&self) -> &[usize] { + self.dims() + } + + fn data(&self) -> Cow<[u8]> { + // This copies data from GPU to CPU. + // TODO: Avoid the unwrap here. + Cow::Owned(convert_back(self).unwrap()) + } + + fn data_len(&self) -> usize { + let n: usize = self.dims().iter().product(); + let bytes_per_element = (*self).dtype().size_in_bytes(); + n * bytes_per_element + } +} + +impl Tensor { + pub fn save_safetensors>(&self, name: &str, filename: P) -> Result<()> { + let data = [(name, self.clone())]; + Ok(st::serialize_to_file(data, &None, filename.as_ref())?) + } +} + +fn convert_slice(data: &[u8], shape: &[usize], device: &Device) -> Result { + let size_in_bytes = T::DTYPE.size_in_bytes(); + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. + let data: &[T] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; + Tensor::from_slice(data, shape, device) + } else { + // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast + // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access + let mut c: Vec = Vec::with_capacity(elem_count); + // SAFETY: We just created c, so the allocated memory is necessarily + // contiguous and non overlapping with the view's data. + // We're downgrading the `c` pointer from T to u8, which removes alignment + // constraints. + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); + c.set_len(elem_count) + } + Tensor::from_slice(&c, shape, device) + } +} + +fn convert_slice_with_cast Result>( + data: &[u8], + shape: &[usize], + device: &Device, + conv: F, +) -> Result { + let size_in_bytes = std::mem::size_of::(); + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. 
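+ // Reinterpreting the byte buffer as `&[T]` is only sound when the pointer is
+ // aligned for `T`; the modulo check above guarantees that, so this fast path
+ // avoids any copy. The `else` branch below handles unaligned buffers by first
+ // copying the bytes into a freshly allocated `Vec<T>`, whose allocation is
+ // always suitably aligned for `T`.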
+ let data: &[T] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; + let data = data.iter().map(|t| conv(*t)).collect::>>()?; + Tensor::from_vec(data, shape, device) + } else { + // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast + // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access + let mut c: Vec = Vec::with_capacity(elem_count); + // SAFETY: We just created c, so the allocated memory is necessarily + // contiguous and non overlapping with the view's data. + // We're downgrading the `c` pointer from T to u8, which removes alignment + // constraints. + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); + c.set_len(elem_count) + } + let c = c.into_iter().map(conv).collect::>>()?; + Tensor::from_vec(c, shape, device) + } +} + +fn convert_with_cast_ Result>( + view: &st::TensorView<'_>, + device: &Device, + conv: F, +) -> Result { + convert_slice_with_cast::(view.data(), view.shape(), device, conv) +} + +fn convert_(view: &st::TensorView<'_>, device: &Device) -> Result { + convert_slice::(view.data(), view.shape(), device) +} + +fn convert_back_(mut vs: Vec) -> Vec { + let size_in_bytes = T::DTYPE.size_in_bytes(); + let length = vs.len() * size_in_bytes; + let capacity = vs.capacity() * size_in_bytes; + let ptr = vs.as_mut_ptr() as *mut u8; + // Don't run the destructor for Vec + std::mem::forget(vs); + // SAFETY: + // + // Every T is larger than u8, so there is no issue regarding alignment. + // This re-interpret the Vec as a Vec. + unsafe { Vec::from_raw_parts(ptr, length, capacity) } +} + +pub trait Load { + fn load(&self, device: &Device) -> Result; +} + +impl<'a> Load for st::TensorView<'a> { + fn load(&self, device: &Device) -> Result { + convert(self, device) + } +} + +impl Tensor { + pub fn from_raw_buffer( + data: &[u8], + dtype: DType, + shape: &[usize], + device: &Device, + ) -> Result { + match dtype { + DType::U8 => convert_slice::(data, shape, device), + DType::U32 => convert_slice::(data, shape, device), + DType::I16 => convert_slice::(data, shape, device), + DType::I32 => convert_slice::(data, shape, device), + DType::I64 => convert_slice::(data, shape, device), + DType::BF16 => convert_slice::(data, shape, device), + DType::F16 => convert_slice::(data, shape, device), + DType::F32 => convert_slice::(data, shape, device), + DType::F64 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), + } + } +} + +fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { + match view.dtype() { + st::Dtype::U8 => convert_::(view, device), + st::Dtype::U16 => { + let conv = |x| Ok(u32::from(x)); + convert_with_cast_::(view, device, conv) + } + st::Dtype::U32 => convert_::(view, device), + st::Dtype::I16 => convert_::(view, device), + st::Dtype::I32 => convert_::(view, device), + st::Dtype::I64 => convert_::(view, device), + st::Dtype::BF16 => convert_::(view, device), + st::Dtype::F16 => convert_::(view, device), + st::Dtype::F32 => convert_::(view, device), + st::Dtype::F64 => convert_::(view, device), + dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), + } +} + +fn convert_back(tensor: &Tensor) -> Result> { + // TODO: This makes an unnecessary copy when the tensor is on the cpu. 
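+ // Each arm of the match below extracts the data as a `Vec<T>` and hands it to
+ // `convert_back_`, which reinterprets the same allocation as a `Vec<u8>` without
+ // copying: the original vector is forgotten and rebuilt with its length and
+ // capacity scaled by `size_of::<T>()`.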
+ let tensor = tensor.flatten_all()?; + match tensor.dtype() { + DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I16 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), + } +} + +pub fn load>(filename: P, device: &Device) -> Result> { + let data = std::fs::read(filename.as_ref())?; + load_buffer(&data[..], device) +} + +pub fn load_buffer(data: &[u8], device: &Device) -> Result> { + let st = safetensors::SafeTensors::deserialize(data)?; + st.tensors() + .into_iter() + .map(|(name, view)| Ok((name, view.load(device)?))) + .collect() +} + +pub fn save + Ord + std::fmt::Display, P: AsRef>( + tensors: &HashMap, + filename: P, +) -> Result<()> { + Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) +} + +#[derive(yoke::Yokeable)] +struct SafeTensors_<'a>(SafeTensors<'a>); + +pub struct MmapedSafetensors { + safetensors: Vec, memmap2::Mmap>>, + routing: Option>, +} + +impl MmapedSafetensors { + /// Creates a wrapper around a memory mapped file and deserialize the safetensors header. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn new>(p: P) -> Result { + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let file = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + let safetensors = yoke::Yoke::, memmap2::Mmap>::try_attach_to_cart( + file, + |data: &[u8]| { + let st = safetensors::SafeTensors::deserialize(data) + .map_err(|e| Error::from(e).with_path(p))?; + Ok::<_, Error>(SafeTensors_(st)) + }, + )?; + Ok(Self { + safetensors: vec![safetensors], + routing: None, + }) + } + + /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers. + /// + /// If a tensor name appears in multiple files, the last entry is returned. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. 
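+ // As in `new` above, each mapped file is paired with its parsed header through a
+ // `yoke::Yoke`, which ties the borrowed `SafeTensors<'_>` to the `memmap2::Mmap`
+ // that owns the bytes, so both can be stored together without self-referential
+ // lifetime problems. The `routing` map records, for every tensor name, the index
+ // of the file providing it; files are inserted in order, so a later file wins
+ // when names collide. Hypothetical usage sketch (file and tensor names are made
+ // up, not part of this change):
+ //     let st = unsafe { MmapedSafetensors::multi(&["a.safetensors", "b.safetensors"])? };
+ //     let t = st.load("model.weight", &Device::Cpu)?;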
+ pub unsafe fn multi>(paths: &[P]) -> Result { + let mut routing = HashMap::new(); + let mut safetensors = vec![]; + for (index, p) in paths.iter().enumerate() { + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let file = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + let data = yoke::Yoke::, memmap2::Mmap>::try_attach_to_cart( + file, + |data: &[u8]| { + let st = safetensors::SafeTensors::deserialize(data) + .map_err(|e| Error::from(e).with_path(p))?; + Ok::<_, Error>(SafeTensors_(st)) + }, + )?; + for k in data.get().0.names() { + routing.insert(k.to_string(), index); + } + safetensors.push(data) + } + Ok(Self { + safetensors, + routing: Some(routing), + }) + } + + pub fn load(&self, name: &str, dev: &Device) -> Result { + self.get(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + let mut tensors = vec![]; + for safetensors in self.safetensors.iter() { + tensors.push(safetensors.get().0.tensors()) + } + tensors.into_iter().flatten().collect() + } + + pub fn get(&self, name: &str) -> Result> { + let index = match &self.routing { + None => 0, + Some(routing) => { + let index = routing.get(name).ok_or_else(|| { + Error::CannotFindTensor { + path: name.to_string(), + } + .bt() + })?; + *index + } + }; + Ok(self.safetensors[index].get().0.tensor(name)?) + } +} + +pub struct SliceSafetensors<'a> { + safetensors: SafeTensors<'a>, +} + +impl<'a> SliceSafetensors<'a> { + /// Creates a wrapper around a binary buffer and deserialize the safetensors header. + pub fn new(buffer: &'a [u8]) -> Result { + let safetensors = safetensors::SafeTensors::deserialize(buffer)?; + Ok(Self { safetensors }) + } + + pub fn load(&self, name: &str, dev: &Device) -> Result { + self.safetensors.tensor(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + self.safetensors.tensors() + } + + pub fn get(&self, name: &str) -> Result> { + Ok(self.safetensors.tensor(name)?) + } +} + +pub struct BufferedSafetensors { + safetensors: yoke::Yoke, Vec>, +} + +impl BufferedSafetensors { + /// Creates a wrapper around a binary buffer and deserialize the safetensors header. + pub fn new(buffer: Vec) -> Result { + let safetensors = yoke::Yoke::, Vec>::try_attach_to_cart( + buffer, + |data: &[u8]| { + let st = safetensors::SafeTensors::deserialize(data)?; + Ok::<_, Error>(SafeTensors_(st)) + }, + )?; + Ok(Self { safetensors }) + } + + pub fn load(&self, name: &str, dev: &Device) -> Result { + self.get(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + self.safetensors.get().0.tensors() + } + + pub fn get(&self, name: &str) -> Result> { + Ok(self.safetensors.get().0.tensor(name)?) + } +} + +pub struct MmapedFile { + path: std::path::PathBuf, + inner: memmap2::Mmap, +} + +impl MmapedFile { + /// Creates a wrapper around a memory mapped file from which you can retrieve + /// tensors using [`MmapedFile::deserialize`] + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. 
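+ // Unlike the wrappers above, this type does not parse the safetensors header up
+ // front: `MmapedSafetensors` owns both the maps and their parsed headers,
+ // `SliceSafetensors` borrows a caller-provided byte slice, and
+ // `BufferedSafetensors` owns a `Vec<u8>`, while `MmapedFile` only holds the
+ // mapping and defers parsing to `deserialize`.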
+ pub unsafe fn new>(p: P) -> Result { + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let inner = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + Ok(Self { + inner, + path: p.to_path_buf(), + }) + } + + pub fn deserialize(&self) -> Result> { + let st = safetensors::SafeTensors::deserialize(&self.inner) + .map_err(|e| Error::from(e).with_path(&self.path))?; + Ok(st) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn save_single_tensor() { + let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); + t.save_safetensors("t", "t.safetensors").unwrap(); + let bytes = std::fs::read("t.safetensors").unwrap(); + assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); + std::fs::remove_file("t.safetensors").unwrap(); + } + + #[test] + fn save_load_multiple_tensors() { + let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); + let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap(); + let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect(); + save(&map, "multi.safetensors").unwrap(); + + let weights = load("multi.safetensors", &Device::Cpu).unwrap(); + assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]); + assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]); + let bytes = std::fs::read("multi.safetensors").unwrap(); + assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); + std::fs::remove_file("multi.safetensors").unwrap(); + } +} diff --git a/mcandle/mcandle-core/src/scalar.rs b/mcandle/mcandle-core/src/scalar.rs new file mode 100644 index 000000000..43e1f4c8c --- /dev/null +++ b/mcandle/mcandle-core/src/scalar.rs @@ -0,0 +1,23 @@ +use crate::{Result, Tensor, WithDType}; + +pub enum TensorScalar { + Tensor(Tensor), + Scalar(Tensor), +} + +pub trait TensorOrScalar { + fn to_tensor_scalar(self) -> Result; +} + +impl TensorOrScalar for &Tensor { + fn to_tensor_scalar(self) -> Result { + Ok(TensorScalar::Tensor(self.clone())) + } +} + +impl TensorOrScalar for T { + fn to_tensor_scalar(self) -> Result { + let scalar = Tensor::new(self, &crate::Device::Cpu)?; + Ok(TensorScalar::Scalar(scalar)) + } +} diff --git a/mcandle/mcandle-core/src/shape.rs b/mcandle/mcandle-core/src/shape.rs new file mode 100644 index 000000000..ca05d216a --- /dev/null +++ b/mcandle/mcandle-core/src/shape.rs @@ -0,0 +1,639 @@ +//! The shape of a tensor is a tuple with the size of each of its dimensions. 
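+// `Shape` can be built from `()`, a single `usize`, tuples of up to six `usize`s,
+// slices and `Vec<usize>` (see the `From` impls below), which is what lets call
+// sites pass a plain tuple wherever a shape is expected. A minimal sketch of the
+// conversions (illustrative only, not part of this change):
+//     let s: Shape = (2usize, 3usize).into();
+//     assert_eq!(s.dims(), &[2, 3]);
+//     assert_eq!(s.elem_count(), 6);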
+#![allow(clippy::redundant_closure_call)] +use crate::{Error, Result}; + +#[derive(Clone, PartialEq, Eq)] +pub struct Shape(Vec); + +pub const SCALAR: Shape = Shape(vec![]); + +impl std::fmt::Debug for Shape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", &self.dims()) + } +} + +impl From<&[usize; C]> for Shape { + fn from(dims: &[usize; C]) -> Self { + Self(dims.to_vec()) + } +} + +impl From<&[usize]> for Shape { + fn from(dims: &[usize]) -> Self { + Self(dims.to_vec()) + } +} + +impl From<&Shape> for Shape { + fn from(shape: &Shape) -> Self { + Self(shape.0.to_vec()) + } +} + +impl From<()> for Shape { + fn from(_: ()) -> Self { + Self(vec![]) + } +} + +impl From for Shape { + fn from(d1: usize) -> Self { + Self(vec![d1]) + } +} + +impl From<(usize,)> for Shape { + fn from(d1: (usize,)) -> Self { + Self(vec![d1.0]) + } +} + +impl From<(usize, usize)> for Shape { + fn from(d12: (usize, usize)) -> Self { + Self(vec![d12.0, d12.1]) + } +} + +impl From<(usize, usize, usize)> for Shape { + fn from(d123: (usize, usize, usize)) -> Self { + Self(vec![d123.0, d123.1, d123.2]) + } +} + +impl From<(usize, usize, usize, usize)> for Shape { + fn from(d1234: (usize, usize, usize, usize)) -> Self { + Self(vec![d1234.0, d1234.1, d1234.2, d1234.3]) + } +} + +impl From<(usize, usize, usize, usize, usize)> for Shape { + fn from(d12345: (usize, usize, usize, usize, usize)) -> Self { + Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4]) + } +} + +impl From<(usize, usize, usize, usize, usize, usize)> for Shape { + fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self { + Self(vec![ + d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5, + ]) + } +} + +impl From> for Shape { + fn from(dims: Vec) -> Self { + Self(dims) + } +} + +macro_rules! extract_dims { + ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { + pub fn $fn_name(dims: &[usize]) -> Result<$out_type> { + if dims.len() != $cnt { + Err(Error::UnexpectedNumberOfDims { + expected: $cnt, + got: dims.len(), + shape: Shape::from(dims), + } + .bt()) + } else { + Ok($dims(dims)) + } + } + + impl Shape { + pub fn $fn_name(&self) -> Result<$out_type> { + $fn_name(self.0.as_slice()) + } + } + + impl crate::Tensor { + pub fn $fn_name(&self) -> Result<$out_type> { + self.shape().$fn_name() + } + } + + impl std::convert::TryInto<$out_type> for Shape { + type Error = crate::Error; + fn try_into(self) -> std::result::Result<$out_type, Self::Error> { + self.$fn_name() + } + } + }; +} + +impl Shape { + pub fn from_dims(dims: &[usize]) -> Self { + Self(dims.to_vec()) + } + + /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc. + pub fn rank(&self) -> usize { + self.0.len() + } + + pub fn into_dims(self) -> Vec { + self.0 + } + + /// The dimensions as a slice of `usize`. + pub fn dims(&self) -> &[usize] { + &self.0 + } + + /// The dimension size for a specified dimension index. + pub fn dim(&self, dim: D) -> Result { + let dim = dim.to_index(self, "dim")?; + Ok(self.dims()[dim]) + } + + /// The total number of elements, this is the product of all dimension sizes. + pub fn elem_count(&self) -> usize { + self.0.iter().product() + } + + /// The strides given in number of elements for a contiguous n-dimensional + /// arrays using this shape. 
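+ // For example, a shape of `(2, 3, 4)` has contiguous (row-major) strides
+ // `[12, 4, 1]`: each dimension's stride is the product of all dimension sizes to
+ // its right, which is what the reversed `scan` below accumulates.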
+ pub(crate) fn stride_contiguous(&self) -> Vec { + let mut stride: Vec<_> = self + .0 + .iter() + .rev() + .scan(1, |prod, u| { + let prod_pre_mult = *prod; + *prod *= u; + Some(prod_pre_mult) + }) + .collect(); + stride.reverse(); + stride + } + + /// Returns true if the strides are C contiguous (aka row major). + pub fn is_contiguous(&self, stride: &[usize]) -> bool { + if self.0.len() != stride.len() { + return false; + } + let mut acc = 1; + for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() { + if dim > 1 && stride != acc { + return false; + } + acc *= dim; + } + true + } + + /// Returns true if the strides are Fortran contiguous (aka column major). + pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool { + if self.0.len() != stride.len() { + return false; + } + let mut acc = 1; + for (&stride, &dim) in stride.iter().zip(self.0.iter()) { + if dim > 1 && stride != acc { + return false; + } + acc *= dim; + } + true + } + + /// Modifies the shape by adding a list of additional dimensions at the end of the existing + /// dimensions. + pub fn extend(mut self, additional_dims: &[usize]) -> Self { + self.0.extend(additional_dims); + self + } + + /// Check whether the two shapes are compatible for broadcast, and if it is the case return the + /// broadcasted shape. This is to be used for binary pointwise ops. + pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result { + let lhs = self; + let lhs_dims = lhs.dims(); + let rhs_dims = rhs.dims(); + let lhs_ndims = lhs_dims.len(); + let rhs_ndims = rhs_dims.len(); + let bcast_ndims = usize::max(lhs_ndims, rhs_ndims); + let mut bcast_dims = vec![0; bcast_ndims]; + for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() { + let rev_idx = bcast_ndims - idx; + let l_value = if lhs_ndims < rev_idx { + 1 + } else { + lhs_dims[lhs_ndims - rev_idx] + }; + let r_value = if rhs_ndims < rev_idx { + 1 + } else { + rhs_dims[rhs_ndims - rev_idx] + }; + *bcast_value = if l_value == r_value { + l_value + } else if l_value == 1 { + r_value + } else if r_value == 1 { + l_value + } else { + Err(Error::ShapeMismatchBinaryOp { + lhs: lhs.clone(), + rhs: rhs.clone(), + op, + } + .bt())? 
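+ // This arm is reached when, comparing the shapes right-to-left, the two sizes
+ // differ and neither of them is 1; such dimensions cannot be broadcast together.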
+ } + } + Ok(Shape::from(bcast_dims)) + } + + pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> { + let lhs = self; + let lhs_dims = lhs.dims(); + let rhs_dims = rhs.dims(); + if lhs_dims.len() < 2 || rhs_dims.len() < 2 { + crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}") + } + let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]); + let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]); + if lhs_k != rhs_k { + crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}") + } + + let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]); + let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]); + let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?; + let bcast_dims = bcast.dims(); + + let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat(); + let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat(); + Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs))) + } +} + +pub trait Dim { + fn to_index(&self, shape: &Shape, op: &'static str) -> Result; + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result; +} + +impl Dim for usize { + fn to_index(&self, shape: &Shape, op: &'static str) -> Result { + let dim = *self; + if dim >= shape.dims().len() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: dim as i32, + op, + } + .bt())? + } else { + Ok(dim) + } + } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result { + let dim = *self; + if dim > shape.dims().len() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: dim as i32, + op, + } + .bt())? + } else { + Ok(dim) + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum D { + Minus1, + Minus2, + Minus(usize), +} + +impl D { + fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error { + let dim = match self { + Self::Minus1 => -1, + Self::Minus2 => -2, + Self::Minus(u) => -(*u as i32), + }; + Error::DimOutOfRange { + shape: shape.clone(), + dim, + op, + } + .bt() + } +} + +impl Dim for D { + fn to_index(&self, shape: &Shape, op: &'static str) -> Result { + let rank = shape.rank(); + match self { + Self::Minus1 if rank >= 1 => Ok(rank - 1), + Self::Minus2 if rank >= 2 => Ok(rank - 2), + Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u), + _ => Err(self.out_of_range(shape, op)), + } + } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result { + let rank = shape.rank(); + match self { + Self::Minus1 => Ok(rank), + Self::Minus2 if rank >= 1 => Ok(rank - 1), + Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u), + _ => Err(self.out_of_range(shape, op)), + } + } +} + +pub trait Dims: Sized { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result>; + + fn to_indexes(self, shape: &Shape, op: &'static str) -> Result> { + let dims = self.to_indexes_internal(shape, op)?; + for (i, &dim) in dims.iter().enumerate() { + if dims[..i].contains(&dim) { + Err(Error::DuplicateDimIndex { + shape: shape.clone(), + dims: dims.clone(), + op, + } + .bt())? + } + if dim >= shape.rank() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: dim as i32, + op, + } + .bt())? 
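+ // Together the two checks above make `to_indexes` reject duplicate dimension
+ // indices as well as indices outside the shape's rank.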
+ } + } + Ok(dims) + } +} + +impl Dims for Vec { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(self) + } +} + +impl Dims for [usize; N] { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(self.to_vec()) + } +} + +impl Dims for &[usize] { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(self.to_vec()) + } +} + +impl Dims for () { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(vec![]) + } +} + +impl Dims for D { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let dim = self.to_index(shape, op)?; + Ok(vec![dim]) + } +} + +impl Dims for (D,) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let dim = self.0.to_index(shape, op)?; + Ok(vec![dim]) + } +} + +impl Dims for (D1, D2) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + Ok(vec![d0, d1]) + } +} + +impl Dims for (D1, D2, D3) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + Ok(vec![d0, d1, d2]) + } +} + +impl Dims for (D1, D2, D3, D4) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3]) + } +} + +impl Dims for (D1, D2, D3, D4, D5) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4]) + } +} + +impl Dims for (D1, D2, D3, D4, D5, D6) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + let d5 = self.5.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4, d5]) + } +} + +extract_dims!(dims0, 0, |_: &[usize]| (), ()); +extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); +extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); +extract_dims!( + dims3, + 3, + |d: &[usize]| (d[0], d[1], d[2]), + (usize, usize, usize) +); +extract_dims!( + dims4, + 4, + |d: &[usize]| (d[0], d[1], d[2], d[3]), + (usize, usize, usize, usize) +); +extract_dims!( + dims5, + 5, + |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]), + (usize, usize, usize, usize, usize) +); + +pub trait ShapeWithOneHole { + fn into_shape(self, el_count: usize) -> Result; +} + +impl> ShapeWithOneHole for S { + fn into_shape(self, _el_count: usize) -> Result { + Ok(self.into()) + } +} + +impl ShapeWithOneHole for ((),) { + fn into_shape(self, el_count: usize) -> Result { + Ok(el_count.into()) + } +} + +fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result { + if prod_d == 0 { + crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}") + } + if el_count % prod_d != 0 { + crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}") + } + Ok(el_count / prod_d) +} + +impl ShapeWithOneHole for ((), usize) { + fn 
into_shape(self, el_count: usize) -> Result { + let ((), d1) = self; + Ok((hole_size(el_count, d1, &self)?, d1).into()) + } +} + +impl ShapeWithOneHole for (usize, ()) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, ()) = self; + Ok((d1, hole_size(el_count, d1, &self)?).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize) { + fn into_shape(self, el_count: usize) -> Result { + let ((), d1, d2) = self; + Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, (), d2) = self; + Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, d2, ()) = self; + Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result { + let ((), d1, d2, d3) = self; + let d = hole_size(el_count, d1 * d2 * d3, &self)?; + Ok((d, d1, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, (), d2, d3) = self; + let d = hole_size(el_count, d1 * d2 * d3, &self)?; + Ok((d1, d, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, d2, (), d3) = self; + let d = hole_size(el_count, d1 * d2 * d3, &self)?; + Ok((d1, d2, d, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, d2, d3, ()) = self; + let d = hole_size(el_count, d1 * d2 * d3, &self)?; + Ok((d1, d2, d3, d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result { + let ((), d1, d2, d3, d4) = self; + let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; + Ok((d, d1, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, (), d2, d3, d4) = self; + let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; + Ok((d1, d, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, d2, (), d3, d4) = self; + let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; + Ok((d1, d2, d, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, d2, d3, (), d4) = self; + let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; + Ok((d1, d2, d3, d, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result { + let (d1, d2, d3, d4, ()) = self; + let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?; + Ok((d1, d2, d3, d4, d).into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn stride() { + let shape = Shape::from(()); + assert_eq!(shape.stride_contiguous(), Vec::::new()); + let shape = Shape::from(42); + assert_eq!(shape.stride_contiguous(), [1]); + let shape = Shape::from((42, 1337)); + assert_eq!(shape.stride_contiguous(), [1337, 1]); + let shape = Shape::from((299, 792, 458)); + assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); + } +} diff --git a/mcandle/mcandle-core/src/sort.rs 
b/mcandle/mcandle-core/src/sort.rs new file mode 100644 index 000000000..491b5eb0b --- /dev/null +++ b/mcandle/mcandle-core/src/sort.rs @@ -0,0 +1,248 @@ +use crate::{Result, Tensor}; +use rayon::prelude::*; + +#[derive(Debug, Clone, Copy)] +struct ArgSort { + asc: bool, + last_dim: usize, +} + +impl ArgSort { + fn asort(&self, vs: &[T], layout: &crate::Layout) -> Vec { + #[allow(clippy::uninit_vec)] + // Safety: indexes are set later in the parallelized section. + let mut sort_indexes = unsafe { + let el_count = layout.shape().elem_count(); + let mut v = Vec::with_capacity(el_count); + v.set_len(el_count); + v + }; + if self.asc { + sort_indexes + .par_chunks_exact_mut(self.last_dim) + .zip(vs.par_chunks_exact(self.last_dim)) + .for_each(|(indexes, vs)| { + indexes + .iter_mut() + .enumerate() + .for_each(|(i, v)| *v = i as u32); + indexes.sort_by(|&i, &j| { + vs[i as usize] + .partial_cmp(&vs[j as usize]) + .unwrap_or(std::cmp::Ordering::Greater) + }) + }); + } else { + sort_indexes + .par_chunks_exact_mut(self.last_dim) + .zip(vs.par_chunks_exact(self.last_dim)) + .for_each(|(indexes, vs)| { + indexes + .iter_mut() + .enumerate() + .for_each(|(i, v)| *v = i as u32); + indexes.sort_by(|&j, &i| { + vs[i as usize] + .partial_cmp(&vs[j as usize]) + .unwrap_or(std::cmp::Ordering::Greater) + }) + }); + } + sort_indexes + } +} + +impl crate::CustomOp1 for ArgSort { + fn name(&self) -> &'static str { + "argsort" + } + + fn cpu_fwd( + &self, + storage: &crate::CpuStorage, + layout: &crate::Layout, + ) -> Result<(crate::CpuStorage, crate::Shape)> { + let sort_indexes = match storage { + crate::CpuStorage::U8(vs) => self.asort(vs, layout), + crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), + crate::CpuStorage::I64(vs) => self.asort(vs, layout), + crate::CpuStorage::BF16(vs) => self.asort(vs, layout), + crate::CpuStorage::F16(vs) => self.asort(vs, layout), + crate::CpuStorage::F32(vs) => self.asort(vs, layout), + crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), + }; + let sort_indexes = crate::CpuStorage::U32(sort_indexes); + Ok((sort_indexes, layout.shape().into())) + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &crate::CudaStorage, + layout: &crate::Layout, + ) -> Result<(crate::CudaStorage, crate::Shape)> { + use crate::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, + }; + use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr}; + use crate::{CudaDevice, WithDType}; + + impl Map1Any for ArgSort { + fn f) -> S>( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &crate::Layout, + _wrap: W, + ) -> Result { + let slice = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let elem_count = layout.shape().elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let func = if self.asc { + dev.get_or_load_func(&kernel_name::("asort_asc"), kernels::SORT)? + } else { + dev.get_or_load_func(&kernel_name::("asort_desc"), kernels::SORT)? 
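+ // Both kernel variants sort one row per thread block; `ncols` is padded to the
+ // next power of two below, and that padded width is used both as the block
+ // dimension and to size the shared-memory scratch buffer for the sort.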
+ }; + let ncols = self.last_dim; + let nrows = elem_count / ncols; + let ncols_pad = next_power_of_2(ncols); + let params = (&slice, &dst, ncols as i32, ncols_pad as i32); + let cfg = LaunchConfig { + grid_dim: (1, nrows as u32, 1), + block_dim: (ncols_pad as u32, 1, 1), + shared_mem_bytes: (ncols_pad * std::mem::size_of::()) as u32, + }; + unsafe { func.launch(cfg, params) }.w()?; + Ok(S::U32(dst)) + } + } + + use crate::backend::BackendStorage; + let dev = storage.device(); + let slice = self.map(&storage.slice, dev, layout)?; + let dst = crate::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, layout.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &crate::MetalStorage, + layout: &crate::Layout, + ) -> Result<(crate::MetalStorage, crate::Shape)> { + use crate::backend::BackendStorage; + use crate::DType; + + let name = { + if self.asc { + match storage.dtype() { + DType::BF16 => "asort_asc_bf16", + DType::F16 => "asort_asc_f16", + DType::F32 => "asort_asc_f32", + DType::F64 => "asort_asc_f64", + DType::U8 => "asort_asc_u8", + DType::U32 => "asort_asc_u32", + DType::I64 => "asort_asc_i64", + DType::I32 => "asort_asc_i32", + DType::I16 => "asort_asc_i16", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), + } + } else { + match storage.dtype() { + DType::BF16 => "asort_desc_bf16", + DType::F16 => "asort_desc_f16", + DType::F32 => "asort_desc_f32", + DType::F64 => "asort_desc_f64", + DType::U8 => "asort_desc_u8", + DType::U32 => "asort_desc_u32", + DType::I64 => "asort_desc_i64", + DType::I32 => "asort_desc_i32", + DType::I16 => "asort_desc_i16", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), + } + } + }; + let device = storage.device(); + let kernels = device.kernels(); + let command_buffer = device.command_buffer()?; + let el = layout.shape().elem_count(); + let ncols = self.last_dim; + let nrows = el / ncols; + let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype()); + let dst = device.new_buffer(el, DType::U32, "asort")?; + let mut ncols_pad = 1; + while ncols_pad < ncols { + ncols_pad *= 2; + } + mcandle_metal_kernels::call_arg_sort( + device.metal_device(), + &command_buffer, + kernels, + name, + nrows, + ncols, + ncols_pad, + src, + &dst, + ) + .map_err(crate::Error::wrap)?; + let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32); + Ok((dst, layout.shape().clone())) + } +} + +#[allow(unused)] +fn next_power_of_2(x: usize) -> usize { + let mut n = 1; + while n < x { + n *= 2 + } + n +} + +impl Tensor { + /// Returns the indices that sort the tensor along the last dimension. + /// + /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in + /// descending order. The sort is unstable so there is no guarantees on the final order when it + /// comes to ties. + pub fn arg_sort_last_dim(&self, asc: bool) -> Result { + if !self.is_contiguous() { + return Err(crate::Error::RequiresContiguous { + op: "arg_sort_last_dim", + }); + } + let last_dim = match self.dims().last() { + None => crate::bail!("empty last-dim in arg-sort"), + Some(last_dim) => *last_dim, + }; + // No need for a backward pass for arg sort. + self.apply_op1_no_bwd(&ArgSort { asc, last_dim }) + } + + /// Sorts the tensor along the last dimension, returns the sorted tensor together with the + /// sorted indexes. + /// + /// If `asc` is `true`, sorting is in ascending order. 
Otherwise sorting is performed in + /// descending order. The sort is unstable so there is no guarantees on the final order when it + /// comes to ties. + pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> { + if !self.is_contiguous() { + return Err(crate::Error::RequiresContiguous { + op: "sort_last_dim", + }); + } + let asort = self.arg_sort_last_dim(asc)?; + let sorted = self.gather(&asort, crate::D::Minus1)?; + Ok((sorted, asort)) + } +} diff --git a/mcandle/mcandle-core/src/storage.rs b/mcandle/mcandle-core/src/storage.rs new file mode 100644 index 000000000..8ff1cbf82 --- /dev/null +++ b/mcandle/mcandle-core/src/storage.rs @@ -0,0 +1,824 @@ +use crate::backend::BackendStorage; +use crate::op::{self, CmpOp, ReduceOp}; +use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape}; +use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; + +// We do not want to implement Clone on Storage as cloning may fail because of +// out of memory. Instead try_clone should be used. +#[derive(Debug)] +pub enum Storage { + Cpu(CpuStorage), + Cuda(CudaStorage), + Metal(MetalStorage), +} + +impl Storage { + pub fn try_clone(&self, layout: &Layout) -> Result { + match self { + Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())), + Self::Cuda(storage) => { + let storage = storage.try_clone(layout)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.try_clone(layout)?; + Ok(Self::Metal(storage)) + } + } + } + + pub fn device(&self) -> Device { + match self { + Self::Cpu(_) => Device::Cpu, + Self::Cuda(storage) => Device::Cuda(storage.device().clone()), + Self::Metal(storage) => Device::Metal(storage.device().clone()), + } + } + + pub fn dtype(&self) -> DType { + match self { + Self::Cpu(storage) => storage.dtype(), + Self::Cuda(storage) => storage.dtype(), + Self::Metal(storage) => storage.dtype(), + } + } + + pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> { + let lhs_device = self.device(); + let rhs_device = rhs.device(); + let lhs = lhs_device.location(); + let rhs = rhs_device.location(); + let same_device = if self.device().is_metal() { + // On metal, we require the device to be exactly the same rather than + // having the same location. In cuda this is not necessary as all CudaDevice on the + // same GPU will use the same cuda stream. 
+ lhs_device.same_device(&rhs_device) + } else { + lhs == rhs + }; + if !same_device { + Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt()) + } else { + Ok(()) + } + } + + pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> { + let lhs = self.dtype(); + let rhs = rhs.dtype(); + if lhs != rhs { + Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt()) + } else { + Ok(()) + } + } + + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.affine(layout, mul, add)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.affine(layout, mul, add)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.affine(layout, mul, add)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.powf(layout, alpha)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.powf(layout, alpha)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.powf(layout, alpha)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.elu(layout, alpha)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.elu(layout, alpha)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.elu(layout, alpha)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn cmp( + &self, + op: CmpOp, + rhs: &Self, + lhs_layout: &Layout, + rhs_layout: &Layout, + ) -> Result { + self.same_device(rhs, "cmp")?; + self.same_dtype(rhs, "cmp")?; + match (self, rhs) { + (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { + let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(lhs), Self::Cuda(rhs)) => { + let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } + (lhs, rhs) => { + // Should not happen because of the same device check above but we're defensive + // anyway. 
+ Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "cmp", + } + .bt()) + } + } + } + + pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.reduce_op(op, layout, s)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.reduce_op(op, layout, s)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.reduce_op(op, layout, s)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.to_dtype(layout, dtype)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.to_dtype(layout, dtype)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.to_dtype(layout, dtype)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> { + match self { + Self::Cpu(storage) => { + let (storage, shape) = c.cpu_fwd(storage, l)?; + Ok((Self::Cpu(storage), shape)) + } + Self::Cuda(storage) => { + let (storage, shape) = c.cuda_fwd(storage, l)?; + Ok((Self::Cuda(storage), shape)) + } + Self::Metal(storage) => { + let (storage, shape) = c.metal_fwd(storage, l)?; + Ok((Self::Metal(storage), shape)) + } + } + } + + pub(crate) fn apply_op2( + &self, + l1: &Layout, + t2: &Self, + l2: &Layout, + c: &dyn CustomOp2, + ) -> Result<(Self, Shape)> { + self.same_device(t2, c.name())?; + match (self, t2) { + (Self::Cpu(s1), Self::Cpu(s2)) => { + let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?; + Ok((Self::Cpu(s), shape)) + } + (Self::Cuda(s1), Self::Cuda(s2)) => { + let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?; + Ok((Self::Cuda(s), shape)) + } + (Self::Metal(s1), Self::Metal(s2)) => { + let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?; + Ok((Self::Metal(s), shape)) + } + _ => unreachable!(), + } + } + + pub(crate) fn apply_op3( + &self, + l1: &Layout, + t2: &Self, + l2: &Layout, + t3: &Self, + l3: &Layout, + c: &dyn CustomOp3, + ) -> Result<(Self, Shape)> { + self.same_device(t2, c.name())?; + self.same_device(t3, c.name())?; + match (self, t2, t3) { + (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => { + let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?; + Ok((Self::Cpu(s), shape)) + } + (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => { + let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?; + Ok((Self::Cuda(s), shape)) + } + (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => { + let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?; + Ok((Self::Metal(s), shape)) + } + _ => unreachable!(), + } + } + + pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> { + match self { + Self::Cpu(storage) => c.cpu_fwd(storage, l), + Self::Cuda(storage) => c.cuda_fwd(storage, l), + Self::Metal(storage) => c.metal_fwd(storage, l), + } + } + + pub(crate) fn inplace_op2( + &mut self, + l1: &Layout, + t2: &Self, + l2: &Layout, + c: &dyn InplaceOp2, + ) -> Result<()> { + self.same_device(t2, c.name())?; + match (self, t2) { + (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2), + (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2), + (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2), + _ => unreachable!(), + } + } + + pub(crate) fn inplace_op3( + &mut self, + l1: &Layout, + t2: &Self, + l2: &Layout, + t3: &Self, + l3: 
&Layout, + c: &dyn InplaceOp3, + ) -> Result<()> { + self.same_device(t2, c.name())?; + self.same_device(t3, c.name())?; + match (self, t2, t3) { + (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3), + (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3), + (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => { + c.metal_fwd(s1, l1, s2, l2, s3, l3) + } + _ => unreachable!(), + } + } + + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.unary_impl::(layout)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.unary_impl::(layout)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.unary_impl::(layout)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn binary_impl( + &self, + rhs: &Self, + lhs_layout: &Layout, + rhs_layout: &Layout, + ) -> Result { + self.same_device(rhs, B::NAME)?; + self.same_dtype(rhs, B::NAME)?; + match (self, rhs) { + (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(lhs), Self::Cuda(rhs)) => { + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } + (lhs, rhs) => { + // Should not happen because of the same device check above but we're defensive + // anyway. + Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: B::NAME, + } + .bt()) + } + } + } + + pub(crate) fn conv1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv1D, + ) -> Result { + self.same_device(kernel, "conv1d")?; + self.same_dtype(kernel, "conv1d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv1d", + } + .bt()), + } + } + + pub(crate) fn conv_transpose1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConvTranspose1D, + ) -> Result { + self.same_device(kernel, "conv-transpose1d")?; + self.same_dtype(kernel, "conv-transpose1d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv-transpose1d", + } + .bt()), + } + } + + pub(crate) fn conv2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv2D, + ) -> Result { + self.same_device(kernel, "conv2d")?; + 
self.same_dtype(kernel, "conv2d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv2d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv2d", + } + .bt()), + } + } + + pub(crate) fn conv_transpose2d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConvTranspose2D, + ) -> Result { + self.same_device(kernel, "conv_transpose2d")?; + self.same_dtype(kernel, "conv_transpose2d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (Storage::Metal(inp), Storage::Metal(kernel)) => { + let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?; + Ok(Self::Metal(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv_transpose2d", + } + .bt()), + } + } + + pub(crate) fn avg_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.avg_pool2d(layout, kernel_size, stride)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn max_pool2d( + &self, + layout: &Layout, + kernel_size: (usize, usize), + stride: (usize, usize), + ) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.max_pool2d(layout, kernel_size, stride)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.upsample_nearest1d(layout, sz)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result { + match self { + Storage::Cpu(storage) => { + let storage = storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Cpu(storage)) + } + Self::Cuda(storage) => { + let storage = storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Cuda(storage)) + } + Self::Metal(storage) => { + let storage = storage.upsample_nearest2d(layout, h, w)?; + Ok(Self::Metal(storage)) + } + } + } + + pub(crate) fn where_cond( + &self, + layout: &Layout, + t: &Self, + layout_t: &Layout, + f: &Self, + layout_f: &Layout, + ) 
-> Result { + self.same_device(t, "where")?; + self.same_device(f, "where")?; + t.same_dtype(f, "where")?; + match (self, t, f) { + (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => { + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => { + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => { + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; + Ok(Self::Metal(storage)) + } + (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "where", + } + .bt()), + } + } + + pub(crate) fn gather( + &self, + l: &Layout, + indexes: &Self, + indexes_l: &Layout, + d: usize, + ) -> Result { + self.same_device(indexes, "index-add")?; + match (self, indexes) { + (Self::Cpu(s), Self::Cpu(indexes)) => { + let storage = s.gather(l, indexes, indexes_l, d)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(s), Self::Cuda(indexes)) => { + let storage = s.gather(l, indexes, indexes_l, d)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(s), Self::Metal(indexes)) => { + let storage = s.gather(l, indexes, indexes_l, d)?; + Ok(Self::Metal(storage)) + } + _ => unreachable!(), + } + } + + pub(crate) fn scatter_add( + &self, + l: &Layout, + indexes: &Self, + indexes_l: &Layout, + source: &Self, + source_l: &Layout, + d: usize, + ) -> Result { + self.same_device(indexes, "scatter-add")?; + self.same_device(source, "scatter-add")?; + match (self, indexes, source) { + (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => { + let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => { + let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } + _ => unreachable!(), + } + } + + pub(crate) fn index_add( + &self, + l: &Layout, + indexes: &Self, + indexes_l: &Layout, + source: &Self, + source_l: &Layout, + d: usize, + ) -> Result { + self.same_device(indexes, "index-add")?; + self.same_device(source, "index-add")?; + match (self, indexes, source) { + (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => { + let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => { + let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } + _ => unreachable!(), + } + } + + pub(crate) fn index_select( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + d: usize, + ) -> Result { + self.same_device(rhs, "index-select")?; + match (self, rhs) { + (Self::Cpu(lhs), Self::Cpu(rhs)) => { + let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(lhs), Self::Cuda(rhs)) => { + let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?; + 
Ok(Self::Metal(storage)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "index-select", + } + .bt()), + } + } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_layout: &Layout, + rhs_layout: &Layout, + c_layout: &Layout, + ) -> Result<()> { + self.same_device(rhs, "matmul_with_alpha_beta")?; + self.same_dtype(rhs, "matmul_with_alpha_beta")?; + self.same_device(c, "matmul_with_alpha_beta")?; + self.same_dtype(c, "matmul_with_alpha_beta")?; + match (self, rhs, c) { + (Self::Cpu(lhs), Self::Cpu(rhs), Self::Cpu(c)) => { + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (Self::Cuda(lhs), Self::Cuda(rhs), Self::Cuda(c)) => { + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (Self::Metal(lhs), Self::Metal(rhs), Self::Metal(c)) => { + lhs.matmul_with_alpha_beta(rhs, c, s, bmnk, lhs_layout, rhs_layout, c_layout) + } + (lhs, rhs, c) => Err(Error::DeviceMismatchBinaryOp3 { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + c: c.device().location(), + op: "matmul_with_alpha_beta", + } + .bt()), + } + } + + pub(crate) fn matmul_with_alpha( + &self, + rhs: &Self, + s: Option, + bmnk: (usize, usize, usize, usize), + lhs_layout: &Layout, + rhs_layout: &Layout, + ) -> Result { + self.same_device(rhs, "matmul_with_alpha")?; + self.same_dtype(rhs, "matmul_with_alpha")?; + match (self, rhs) { + (Self::Cpu(lhs), Self::Cpu(rhs)) => { + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(lhs), Self::Cuda(rhs)) => { + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(lhs), Self::Metal(rhs)) => { + let storage = lhs.matmul_with_alpha(rhs, s, bmnk, lhs_layout, rhs_layout)?; + Ok(Self::Metal(storage)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "matmul", + } + .bt()), + } + } + + // self, the source can be strided whereas dst is contiguous. + pub(crate) fn copy_strided_src( + &self, + dst: &mut Self, + dst_offset: usize, + src_l: &Layout, + ) -> Result<()> { + match (self, dst) { + (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), + (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), + (Self::Metal(src), Self::Metal(dst)) => { + Ok(src.copy_strided_src(dst, dst_offset, src_l)?) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "copy", + } + .bt()), + } + } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + match (self, dst) { + (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::Cuda(src), Self::Cuda(dst)) => { + Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?) + } + (Self::Metal(src), Self::Metal(dst)) => { + Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?) 
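+                // Note on the two copy primitives here: `copy_strided_src` copies from a
+                // possibly strided source (e.g. a transposed view with stride [1, 2]) into a
+                // contiguous destination, while `copy2d` moves `d1` rows of `d2` elements with
+                // independent row strides, roughly:
+                //   dst[dst_o + i * dst_s .. +d2] <- src[src_o + i * src_s .. +d2]  for i in 0..d1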
+ } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "copy2d", + } + .bt()), + } + } +} diff --git a/mcandle/mcandle-core/src/streaming.rs b/mcandle/mcandle-core/src/streaming.rs new file mode 100644 index 000000000..f70ec51e6 --- /dev/null +++ b/mcandle/mcandle-core/src/streaming.rs @@ -0,0 +1,206 @@ +use crate::{Result, Shape, Tensor}; + +pub trait Dim: crate::shape::Dim + Copy {} +impl Dim for T {} + +/// A stream tensor is used in streaming module. It can either contain an actual tensor or be +/// empty. +#[derive(Clone)] +pub struct StreamTensor(Option); + +impl std::fmt::Debug for StreamTensor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.0 { + Some(t) => write!(f, "{:?}", t.shape()), + None => write!(f, "Empty"), + } + } +} + +impl std::convert::From> for StreamTensor { + fn from(value: Option) -> Self { + Self(value) + } +} + +impl std::convert::From for StreamTensor { + fn from(value: Tensor) -> Self { + Self(Some(value)) + } +} + +impl std::convert::From<()> for StreamTensor { + fn from(_value: ()) -> Self { + Self(None) + } +} + +impl StreamTensor { + pub fn empty() -> Self { + Self(None) + } + + pub fn from_tensor(tensor: Tensor) -> Self { + Self(Some(tensor)) + } + + pub fn shape(&self) -> Option<&Shape> { + self.0.as_ref().map(|t| t.shape()) + } + + pub fn cat2(&self, rhs: &Self, dim: D) -> Result { + let xs = match (&self.0, &rhs.0) { + (Some(lhs), Some(rhs)) => { + let xs = Tensor::cat(&[lhs, rhs], dim)?; + Some(xs) + } + (Some(xs), None) | (None, Some(xs)) => Some(xs.clone()), + (None, None) => None, + }; + Ok(Self(xs)) + } + + pub fn seq_len(&self, dim: D) -> Result { + match &self.0 { + None => Ok(0), + Some(v) => v.dim(dim), + } + } + + pub fn reset(&mut self) { + self.0 = None + } + + pub fn narrow(&self, dim: D, offset: usize, len: usize) -> Result { + let t = match &self.0 { + None => None, + Some(t) => { + let seq_len = t.dim(dim)?; + if seq_len <= offset { + None + } else { + let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?; + Some(t) + } + } + }; + Ok(Self(t)) + } + + /// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements + /// returned in the first output and the remaining in the second output. + pub fn split(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> { + match &self.0 { + None => Ok((Self::empty(), Self::empty())), + Some(t) => { + let seq_len = t.dim(dim)?; + let lhs_len = usize::min(seq_len, lhs_len); + if lhs_len == 0 { + Ok((Self::empty(), t.clone().into())) + } else { + let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?); + let rhs_len = seq_len - lhs_len; + let rhs = if rhs_len == 0 { + Self::empty() + } else { + Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?) + }; + Ok((lhs, rhs)) + } + } + } + } + + pub fn as_option(&self) -> Option<&Tensor> { + self.0.as_ref() + } + + pub fn apply(&self, m: &M) -> Result { + match &self.0 { + None => Ok(Self::empty()), + Some(t) => Ok(Self::from_tensor(t.apply(m)?)), + } + } +} + +/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform +/// some internal buffering so that enough data has been received for the module to be able to +/// perform some operations. +pub trait StreamingModule { + // TODO: Should we also have a flush method? 
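+    // A usage sketch of the streaming pieces below (illustrative only; it assumes that
+    // `crate::D::Minus1` selects the last axis): a `StreamingBinOp` buffers whichever side
+    // is ahead until the other catches up.
+    //
+    //   let mut add = StreamingBinOp::new(BinOp::Add, crate::D::Minus1);
+    //   let lhs = StreamTensor::from_tensor(Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?);
+    //   let rhs = StreamTensor::from_tensor(Tensor::new(&[10f32, 20.], &Device::Cpu)?);
+    //   let ys = add.step(&lhs, &rhs)?;   // [11., 22.]; the trailing 3. stays buffered
+    //   let ys = add.step(&StreamTensor::empty(), &Tensor::new(&[30f32], &Device::Cpu)?.into())?;
+    //   // the buffered 3. now pairs with 30. -> [33.]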
+ fn step(&mut self, xs: &StreamTensor) -> Result; + fn reset_state(&mut self); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum BinOp { + Add, + Mul, + Sub, + Div, +} + +#[derive(Debug, Clone)] +pub struct StreamingBinOp { + prev_lhs: StreamTensor, + prev_rhs: StreamTensor, + pub op: BinOp, + pub dim: crate::D, +} + +impl StreamingBinOp { + pub fn new(op: BinOp, dim: crate::D) -> Self { + Self { + prev_lhs: StreamTensor::empty(), + prev_rhs: StreamTensor::empty(), + op, + dim, + } + } + + pub fn reset_state(&mut self) { + self.prev_lhs.reset(); + self.prev_rhs.reset(); + } + + pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result { + match self.op { + BinOp::Add => Tensor::add(lhs, rhs), + BinOp::Mul => Tensor::mul(lhs, rhs), + BinOp::Sub => Tensor::sub(lhs, rhs), + BinOp::Div => Tensor::div(lhs, rhs), + } + } + + pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result { + let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?; + let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?; + let lhs_len = lhs.seq_len(self.dim)?; + let rhs_len = rhs.seq_len(self.dim)?; + let common_len = usize::min(lhs_len, rhs_len); + let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?; + let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?; + let ys = match (lhs.0, rhs.0) { + (Some(lhs), Some(rhs)) => { + let ys = self.forward(&lhs, &rhs)?; + StreamTensor::from_tensor(ys) + } + (None, None) => StreamTensor::empty(), + (lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"), + }; + self.prev_lhs = prev_lhs; + self.prev_rhs = prev_rhs; + Ok(ys) + } +} + +/// Simple wrapper that doesn't do any buffering. +pub struct Map(T); + +impl StreamingModule for Map { + fn reset_state(&mut self) {} + + fn step(&mut self, xs: &StreamTensor) -> Result { + xs.apply(&self.0) + } +} diff --git a/mcandle/mcandle-core/src/strided_index.rs b/mcandle/mcandle-core/src/strided_index.rs new file mode 100644 index 000000000..eb6a736f8 --- /dev/null +++ b/mcandle/mcandle-core/src/strided_index.rs @@ -0,0 +1,82 @@ +use crate::Layout; + +/// An iterator over offset position for items of an N-dimensional arrays stored in a +/// flat buffer using some potential strides. +#[derive(Debug)] +pub struct StridedIndex<'a> { + next_storage_index: Option, + multi_index: Vec, + dims: &'a [usize], + stride: &'a [usize], +} + +impl<'a> StridedIndex<'a> { + pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self { + let elem_count: usize = dims.iter().product(); + let next_storage_index = if elem_count == 0 { + None + } else { + // This applies to the scalar case. 
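+            // Worked example of the traversal performed by `next` below: for dims [2, 3]
+            // with stride [3, 1] (contiguous) the yielded offsets are 0, 1, 2, 3, 4, 5;
+            // for a transposed view with stride [1, 2] they are 0, 2, 4, 1, 3, 5.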
+ Some(start_offset) + }; + StridedIndex { + next_storage_index, + multi_index: vec![0; dims.len()], + dims, + stride, + } + } + + pub(crate) fn from_layout(l: &'a Layout) -> Self { + Self::new(l.dims(), l.stride(), l.start_offset()) + } +} + +impl<'a> Iterator for StridedIndex<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let storage_index = match self.next_storage_index { + None => return None, + Some(storage_index) => storage_index, + }; + let mut updated = false; + let mut next_storage_index = storage_index; + for ((multi_i, max_i), stride_i) in self + .multi_index + .iter_mut() + .zip(self.dims.iter()) + .zip(self.stride.iter()) + .rev() + { + let next_i = *multi_i + 1; + if next_i < *max_i { + *multi_i = next_i; + updated = true; + next_storage_index += stride_i; + break; + } else { + next_storage_index -= *multi_i * stride_i; + *multi_i = 0 + } + } + self.next_storage_index = if updated { + Some(next_storage_index) + } else { + None + }; + Some(storage_index) + } +} + +#[derive(Debug)] +pub enum StridedBlocks<'a> { + SingleBlock { + start_offset: usize, + len: usize, + }, + MultipleBlocks { + block_start_index: StridedIndex<'a>, + block_len: usize, + }, +} diff --git a/mcandle/mcandle-core/src/tensor.rs b/mcandle/mcandle-core/src/tensor.rs new file mode 100644 index 000000000..834bac675 --- /dev/null +++ b/mcandle/mcandle-core/src/tensor.rs @@ -0,0 +1,2662 @@ +//! Tensors are N-dimensional matrixes of elements using a single data type. +#![allow(clippy::redundant_closure_call)] +use crate::backend::{BackendDevice, BackendStorage}; +use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp}; +use crate::scalar::TensorOrScalar; +use crate::shape::{Dim, Dims}; +use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape}; +use std::sync::{Arc, RwLock}; + +/// Unique identifier for tensors. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct TensorId(usize); + +impl TensorId { + fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +pub struct Tensor_ { + id: TensorId, + // As we provide inner mutability on the tensor content, the alternatives are: + // - Using a mutex, this would have the highest cost when retrieving the storage but would + // prevent errors when concurrent access takes place. Mutex would also be subject to + // deadlocks for example using the current code if the same tensor is used twice by a single + // binary op. + // - Using a refcell unsafe cell would have some intermediary cost, borrow checking would be + // verified dynamically, but the resulting tensors would not be send or sync. + // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent + // accesses. + // Ideally, we would use Arc for tensors on which we don't plan on modifying the data + // and Arc> for tensors where the data could be modified, e.g. variables but + // that's tricky to encode in the current setup. + storage: Arc>, + layout: Layout, + op: BackpropOp, + is_variable: bool, + dtype: DType, + device: Device, +} + +impl AsRef for Tensor { + fn as_ref(&self) -> &Tensor { + self + } +} + +// Tensors are refcounted so that cloning is cheap when building the op graph. 
+// Storages are also refcounted independently so that its possible to avoid +// copying the storage for operations that only modify the shape or stride. +#[derive(Clone)] +/// The core struct for manipulating tensors. +/// +/// ```rust +/// use mcandle_core::{Tensor, DType, Device}; +/// +/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; +/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?; +/// +/// let c = a.matmul(&b)?; +/// # Ok::<(), mcandle_core::Error>(()) +/// ``` +/// +/// Tensors are reference counted with [`Arc`] so cloning them is cheap. +pub struct Tensor(Arc); + +impl std::ops::Deref for Tensor { + type Target = Tensor_; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +macro_rules! unary_op { + ($fn_name:ident, $op_name:ident) => { + pub fn $fn_name(&self) -> Result { + let shape = self.shape(); + if shape.elem_count() == 0 { + return Ok(self.clone()); + } + let storage = self + .storage() + .unary_impl::(self.layout())?; + let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name)); + Ok(from_storage(storage, shape.clone(), op, false)) + } + }; +} + +macro_rules! binary_op { + ($fn_name:ident, $op_name:ident) => { + pub fn $fn_name(&self, rhs: &Self) -> Result { + let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?; + if shape.elem_count() == 0 { + return Ok(self.clone()); + } + let storage = self.storage().binary_impl::( + &*rhs.storage(), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name)); + Ok(from_storage(storage, shape.clone(), op, false)) + } + }; +} + +macro_rules! binary_op_scalar { + ($fn_name:ident, $op_name:ident) => { + pub fn $fn_name(&self, rhs: T) -> Result { + let rhs = match rhs.to_tensor_scalar()? { + crate::scalar::TensorScalar::Tensor(rhs) => rhs, + crate::scalar::TensorScalar::Scalar(rhs) => rhs + .to_dtype(self.dtype())? + .to_device(self.device())? + .broadcast_as(self.shape())?, + }; + let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?; + if self.elem_count() == 0 { + return Ok(self.clone()); + } + let storage = self.storage().binary_impl::( + &*rhs.storage(), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name)); + Ok(from_storage(storage, shape.clone(), op, false)) + } + }; +} + +macro_rules! broadcast_binary_op { + ($fn_name:ident, $inner_fn_name:ident) => { + pub fn $fn_name(&self, rhs: &Self) -> Result { + let lhs = self; + let shape = lhs + .shape() + .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?; + let l_broadcast = shape != *lhs.shape(); + let r_broadcast = shape != *rhs.shape(); + match (l_broadcast, r_broadcast) { + (true, true) => lhs + .broadcast_as(&shape)? + .$inner_fn_name(&rhs.broadcast_as(&shape)?), + (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?), + (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs), + (false, false) => lhs.$inner_fn_name(rhs), + } + } + }; +} + +/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. 
+pub(crate) fn from_storage>( + storage: Storage, + shape: S, + op: BackpropOp, + is_variable: bool, +) -> Tensor { + let dtype = storage.dtype(); + let device = storage.device(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(storage)), + layout: Layout::contiguous(shape), + op, + is_variable, + dtype, + device, + }; + Tensor(Arc::new(tensor_)) +} + +/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. This has a BackpropOp:none(). +pub fn from_storage_no_op>(storage: Storage, shape: S, is_variable: bool) -> Tensor { + let dtype = storage.dtype(); + let device = storage.device(); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(storage)), + layout: Layout::contiguous(shape), + op: BackpropOp::none(), + is_variable, + dtype, + device, + }; + Tensor(Arc::new(tensor_)) +} + +impl Tensor { + pub(crate) fn ones_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.ones(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with ones. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn ones>(shape: S, dtype: DType, device: &Device) -> Result { + Self::ones_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = a.ones_like()?; + /// // b == a + 1 + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn ones_like(&self) -> Result { + Tensor::ones(self.shape(), self.dtype(), self.device()) + } + + // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from + // the variable module. + pub(crate) fn zeros_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.zeros(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with zeros. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?; + /// // a == b + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn zeros>(shape: S, dtype: DType, device: &Device) -> Result { + Self::zeros_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with ones with same shape, dtype, and device as the other + /// tensor. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = a.zeros_like()?; + /// // b is on CPU f32. + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn zeros_like(&self) -> Result { + Tensor::zeros(self.shape(), self.dtype(), self.device()) + } + + // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from + // the variable module. 
+ pub(crate) unsafe fn empty_impl>( + shape: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let none = BackpropOp::none(); + let shape = shape.into(); + let storage = device.alloc_uninit(&shape, dtype)?; + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor filled with uninitialized memory. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// let a = unsafe { Tensor::empty((2, 3), DType::F32, &Device::Cpu)? }; + /// // a == b + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub unsafe fn empty>(shape: S, dtype: DType, device: &Device) -> Result { + Self::empty_impl(shape, dtype, device, false) + } + + /// Creates a new tensor filled with uninitialized memory of the same shape, dtype, and device as the other + /// tensor. + /// + /// # Safety + /// This returns uninitialized memory. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = unsafe { a.empty_like()? }; + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub unsafe fn empty_like(&self) -> Result { + Tensor::empty(self.shape(), self.dtype(), self.device()) + } + + pub(crate) fn rand_impl, T: crate::FloatDType>( + lo: T, + up: T, + s: S, + device: &Device, + is_variable: bool, + ) -> Result { + let s = s.into(); + let storage = device.rand_uniform(lo, up, &s)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, s, none, is_variable)) + } + + pub(crate) fn rand_f64_impl>( + lo: f64, + up: f64, + s: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let s = s.into(); + let storage = device.rand_uniform_f64(lo, up, &s, dtype)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, s, none, is_variable)) + } + + /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`. + pub fn rand, T: crate::FloatDType>( + lo: T, + up: T, + s: S, + device: &Device, + ) -> Result { + Self::rand_impl(lo, up, s, device, false) + } + + pub fn rand_like(&self, lo: f64, up: f64) -> Result { + Tensor::rand_f64_impl(lo, up, self.shape(), self.dtype(), self.device(), false) + } + + pub(crate) fn randn_impl, T: crate::FloatDType>( + mean: T, + std: T, + s: S, + device: &Device, + is_variable: bool, + ) -> Result { + let s = s.into(); + let storage = device.rand_normal(mean, std, &s)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, s, none, is_variable)) + } + + pub(crate) fn randn_f64_impl>( + mean: f64, + std: f64, + s: S, + dtype: DType, + device: &Device, + is_variable: bool, + ) -> Result { + let s = s.into(); + let storage = device.rand_normal_f64(mean, std, &s, dtype)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, s, none, is_variable)) + } + + pub fn randn_like(&self, mean: f64, stdev: f64) -> Result { + Tensor::randn_f64_impl( + mean, + stdev, + self.shape(), + self.dtype(), + self.device(), + false, + ) + } + + /// Creates a new tensor initialized with values sampled from a normal distribution with the + /// specified `mean` and standard deviation `std`. 
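+    ///
+    /// A minimal sketch (the values are random, so only shape and dtype are checked):
+    /// ```rust
+    /// use mcandle_core::{Tensor, DType, Device};
+    /// let a = Tensor::randn(0f32, 1f32, (2, 3), &Device::Cpu)?;
+    /// assert_eq!(a.dims(), &[2, 3]);
+    /// assert_eq!(a.dtype(), DType::F32);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```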
+ pub fn randn, T: crate::FloatDType>( + mean: T, + std: T, + s: S, + device: &Device, + ) -> Result { + Self::randn_impl(mean, std, s, device, false) + } + + pub(crate) fn new_impl( + array: A, + shape: Shape, + device: &Device, + is_variable: bool, + ) -> Result { + let n: usize = shape.elem_count(); + let buffer_size: usize = array.shape()?.elem_count(); + if buffer_size != n { + return Err(Error::ShapeMismatch { buffer_size, shape }.bt()); + } + let storage = device.storage(array)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor on the specified device using the content and shape of the input. + pub fn new(array: A, device: &Device) -> Result { + let shape = array.shape()?; + Self::new_impl(array, shape, device, false) + } + + /// Returns a new tensor with all the elements having the same specified value. Note that + /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed. + ///```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [3.5, 3.5, 3.5, 3.5], + /// [3.5, 3.5, 3.5, 3.5], + /// ]); + /// # Ok::<(), mcandle_core::Error>(()) + pub fn full>( + value: D, + shape: S, + device: &Device, + ) -> Result { + Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape) + } + + /// Creates a new 1D tensor from an iterator. + ///```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[1.0, 2.0, 3.0, 4.0]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn from_iter( + iter: impl IntoIterator, + device: &Device, + ) -> Result { + let data = iter.into_iter().collect::>(); + let len = data.len(); + Self::from_vec_impl(data, len, device, false) + } + + /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common + /// difference `1` from `start`. + ///```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::arange(2., 5., &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[2., 3., 4.]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn arange(start: D, end: D, device: &Device) -> Result { + Self::arange_step(start, end, D::one(), device) + } + + /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common + /// difference `step` from `start`. 
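+    ///
+    /// A negative `step` counts down from `start`, as in this small sketch:
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let a = Tensor::arange_step(5.0, 2.0, -1.0, &Device::Cpu)?;
+    ///
+    /// assert_eq!(a.to_vec1::<f64>()?, &[5.0, 4.0, 3.0]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```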
+ ///```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec1::()?, &[2.0, 2.5, 3.0, 3.5]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn arange_step( + start: D, + end: D, + step: D, + device: &Device, + ) -> Result { + if D::is_zero(&step) { + bail!("step cannot be zero") + } + let mut data = vec![]; + let mut current = start; + if step >= D::zero() { + while current < end { + data.push(current); + current += step; + } + } else { + while current > end { + data.push(current); + current += step; + } + } + let len = data.len(); + Self::from_vec_impl(data, len, device, false) + } + + pub(crate) fn from_vec_impl, D: crate::WithDType>( + data: Vec, + shape: S, + device: &Device, + is_variable: bool, + ) -> Result { + let shape = shape.into(); + let buffer_size = data.len(); + if buffer_size != shape.elem_count() { + return Err(Error::ShapeMismatch { buffer_size, shape }.bt()); + } + let storage = device.storage_owned(data)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, shape, none, is_variable)) + } + + /// Creates a new tensor initialized with values from the input vector. The number of elements + /// in this vector must be the same as the number of elements defined by the shape. + /// If the device is cpu, no data copy is made. + ///```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [1., 2., 3.], + /// [4., 5., 6.] + /// ]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn from_vec, D: crate::WithDType>( + data: Vec, + shape: S, + device: &Device, + ) -> Result { + Self::from_vec_impl(data, shape, device, false) + } + + /// Creates a new tensor initialized with values from the input slice. The number of elements + /// in this vector must be the same as the number of elements defined by the shape. + ///```rust + /// use mcandle_core::{Tensor, Device}; + /// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.]; + /// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?; + /// + /// assert_eq!(a.to_vec2::()?, &[ + /// [2., 3., 4.], + /// [5., 6., 7.] + /// ]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn from_slice, D: crate::WithDType>( + array: &[D], + shape: S, + device: &Device, + ) -> Result { + let shape = shape.into(); + let n: usize = shape.elem_count(); + let buffer_size: usize = array.len(); + if buffer_size != n { + return Err(Error::ShapeMismatch { buffer_size, shape }.bt()); + } + let storage = device.storage_from_slice(array)?; + let none = BackpropOp::none(); + Ok(from_storage(storage, shape, none, false)) + } + + pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { + let lhs = self.shape(); + let rhs = rhs.shape(); + if lhs != rhs { + Err(Error::ShapeMismatchBinaryOp { + lhs: lhs.clone(), + rhs: rhs.clone(), + op, + } + .bt()) + } else { + Ok(lhs) + } + } + + /// Returns true if the computation graph should track this op, that is if it is + /// a variable or if it has some variable as dependencies. + pub fn track_op(&self) -> bool { + self.is_variable || self.op.is_some() + } + + // TODO: Also make an inplace version or a pre-allocated? This could be tricky + // if this can create cycles in the compute graph. 
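+    // A quick sketch of the element-wise ops generated below (values assume a CPU device):
+    //   let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
+    //   let b = Tensor::new(&[10f32, 20.], &Device::Cpu)?;
+    //   a.add(&a)?            // same-shape op          -> [[2., 4.], [6., 8.]]
+    //   a.broadcast_add(&b)?  // b broadcast to (2, 2)  -> [[11., 22.], [13., 24.]]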
+ binary_op!(add, Add); + binary_op!(mul, Mul); + binary_op!(sub, Sub); + binary_op!(div, Div); + binary_op_scalar!(maximum, Maximum); + binary_op_scalar!(minimum, Minimum); + broadcast_binary_op!(broadcast_add, add); + broadcast_binary_op!(broadcast_mul, mul); + broadcast_binary_op!(broadcast_sub, sub); + broadcast_binary_op!(broadcast_div, div); + broadcast_binary_op!(broadcast_maximum, maximum); + broadcast_binary_op!(broadcast_minimum, minimum); + broadcast_binary_op!(broadcast_eq, eq); + broadcast_binary_op!(broadcast_ne, ne); + broadcast_binary_op!(broadcast_lt, lt); + broadcast_binary_op!(broadcast_le, le); + broadcast_binary_op!(broadcast_gt, gt); + broadcast_binary_op!(broadcast_ge, ge); + + unary_op!(recip, Recip); + unary_op!(neg, Neg); + unary_op!(exp, Exp); + unary_op!(log, Log); + unary_op!(sin, Sin); + unary_op!(cos, Cos); + unary_op!(tanh, Tanh); + unary_op!(abs, Abs); + unary_op!(sqr, Sqr); + unary_op!(sqrt, Sqrt); + unary_op!(gelu, Gelu); + unary_op!(gelu_erf, GeluErf); + unary_op!(erf, Erf); + unary_op!(relu, Relu); + unary_op!(silu, Silu); + unary_op!(ceil, Ceil); + unary_op!(floor, Floor); + unary_op!(round, Round); + unary_op!(sign, Sign); + + /// Round element of the input tensor to the nearest integer. + /// + /// If the number of decimals is negative, it specifies the number of positions to the left of + /// the decimal point. + pub fn round_to(&self, decimals: i32) -> Result { + let mult = 10f64.powi(decimals); + (self * mult)?.round()? * (1f64 / mult) + } + + /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple + /// dimensions, an error is returned instead. + pub fn to_scalar(&self) -> Result { + if self.rank() != 0 { + Err(Error::UnexpectedNumberOfDims { + expected: 0, + got: self.rank(), + shape: self.shape().clone(), + } + .bt())? + } + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + Ok::<_, Error>(data[self.layout().start_offset()]) + }; + match &*self.storage() { + Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + } + } + + /// An alias for `to_scalar`. + pub fn to_vec0(&self) -> Result { + self.to_scalar::() + } + + /// Repeat this tensor along the specified dimensions. + pub fn repeat>(&self, shape: S) -> Result { + // Similar to PyTorch, we extend the number of dimensions of self if needed. + let repeats = shape.into(); + let repeats = repeats.dims(); + let mut inp = if self.rank() < repeats.len() { + let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat(); + self.reshape(shape)? + } else { + self.clone() + }; + for (idx, &repeat) in repeats.iter().enumerate() { + if repeat > 1 { + inp = Tensor::cat(&vec![&inp; repeat], idx)? + } + } + Ok(inp) + } + + /// Creates grids of coordinates specified by the 1D inputs. + /// + /// # Arguments + /// + /// * `args` - A slice of 1D tensors. + /// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the + /// first dimension corresponds to the cardinality of the second input and the second + /// dimension corresponds to the cardinality of the first input. If ij is selected, the + /// dimensions are in the same order as the cardinality of the inputs. 
+ /// + /// # Examples + /// + /// ```rust + /// use mcandle_core::{Tensor, Device, Shape}; + /// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + /// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?; + /// + /// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?; + /// + /// assert_eq!(grids_xy.len(), 2); + /// assert_eq!(grids_xy[0].dims(), &[3, 3]); + /// + /// assert_eq!(grids_xy[0].to_vec2::()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]); + /// assert_eq!(grids_xy[1].to_vec2::()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]); + /// + /// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?; + /// + /// assert_eq!(grids_ij[0].to_vec2::()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]); + /// assert_eq!(grids_ij[1].to_vec2::()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + /// + /// # Errors + /// + /// * Will return `Err` if `args` contains less than 2 tensors. + /// + pub fn meshgrid>(args: &[A], xy_indexing: bool) -> Result> { + if args.len() <= 1 { + Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())? + } + let args: Vec<_> = if xy_indexing { + args.iter().rev().collect() + } else { + args.iter().collect() + }; + + let mut shape = Vec::with_capacity(args.len()); + for arg in args.iter() { + shape.push(arg.as_ref().dims1()?) + } + + let mut grids = Vec::with_capacity(args.len()); + for idx in 0..args.len() { + let mut ones = vec![1usize; args.len()]; + ones[idx] = shape[idx]; + let arg = args[idx].as_ref().reshape(ones)?; + let mut repeats = shape.clone(); + repeats[idx] = 1; + let repeated_tensor = arg.repeat(repeats)?; + grids.push(repeated_tensor); + } + if xy_indexing { + grids.reverse(); + } + Ok(grids) + } + + /// This operation multiplies the input tensor by `mul` then adds `add` and return the result. + /// The input values `mul` and `add` are casted to the appropriate type so some rounding might + /// be performed. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; + /// let a = a.affine(4., -2.)?; + /// assert_eq!(a.to_vec2::()?, &[[-2.0, 2.0], [6.0, 10.0]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn affine(&self, mul: f64, add: f64) -> Result { + if self.elem_count() == 0 { + return Ok(self.clone()); + } + let storage = self.storage().affine(self.layout(), mul, add)?; + let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor. + pub fn elu(&self, alpha: f64) -> Result { + if self.elem_count() == 0 { + return Ok(self.clone()); + } + let storage = self.storage().elu(self.layout(), alpha)?; + let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha)); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Raise the tensor to some float exponent `e`. + pub fn powf(&self, e: f64) -> Result { + if self.elem_count() == 0 { + return Ok(self.clone()); + } + let storage = self.storage().powf(self.layout(), e)?; + let op = BackpropOp::new1(self, |t| Op::Powf(t, e)); + Ok(from_storage(storage, self.shape(), op, false)) + } + + pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { + if dim >= self.dims().len() { + Err(Error::DimOutOfRange { + shape: self.shape().clone(), + dim: dim as i32, + op, + } + .bt())? 
+ } else { + Ok(()) + } + } + + /// Split a tensor into the specified number of chunks, this may return less chunks than + /// specified. + pub fn chunk(&self, chunks: usize, dim: D) -> Result> { + let dim = dim.to_index(self.shape(), "chunk")?; + let size = self.dim(dim)?; + if size < chunks { + (0..size).map(|i| self.narrow(dim, i, 1)).collect() + } else { + let chunk_size = size / chunks; + let cnt_additional = size % chunks; + let mut tensors = vec![]; + let mut sum_chunk_size = 0; + for i in 0..chunks { + let chunk_size = if i < cnt_additional { + chunk_size + 1 + } else { + chunk_size + }; + let tensor = self.narrow(dim, sum_chunk_size, chunk_size)?; + tensors.push(tensor); + sum_chunk_size += chunk_size + } + Ok(tensors) + } + } + + /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` + /// ranges from `start` to `start + len`. + /// ``` + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::new(&[ + /// [0f32, 1., 2.], + /// [3. , 4., 5.], + /// [6. , 7., 8.] + /// ], &Device::Cpu)?; + /// + /// let b = a.narrow(0, 1, 2)?; + /// assert_eq!(b.shape().dims(), &[2, 3]); + /// assert_eq!(b.to_vec2::()?, &[ + /// [3., 4., 5.], + /// [6., 7., 8.] + /// ]); + /// + /// let c = a.narrow(1, 1, 1)?; + /// assert_eq!(c.shape().dims(), &[3, 1]); + /// assert_eq!(c.to_vec2::()?, &[ + /// [1.], + /// [4.], + /// [7.] + /// ]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn narrow(&self, dim: D, start: usize, len: usize) -> Result { + let dims = self.dims(); + let dim = dim.to_index(self.shape(), "narrow")?; + let err = |msg| { + Err::<(), _>( + Error::NarrowInvalidArgs { + shape: self.shape().clone(), + dim, + start, + len, + msg, + } + .bt(), + ) + }; + if start > dims[dim] { + err("start > dim_len")? + } + if start.saturating_add(len) > dims[dim] { + err("start + len > dim_len")? 
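+        // Related sketch: `chunk` above is built on `narrow`; e.g. a 7-element tensor split
+        // into 3 chunks yields narrows of sizes 3, 2 and 2 (the first `7 % 3` chunks get one
+        // extra element).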
+ } + if start == 0 && dims[dim] == len { + Ok(self.clone()) + } else { + let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len)); + let layout = self.layout().narrow(dim, start, len)?; + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout, + op, + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + } + + fn squeeze_dims(self, dims: &[usize]) -> Result { + match dims { + [] => Ok(self), + [i] => self.squeeze(*i), + dims => { + let dims = self + .dims() + .iter() + .enumerate() + .filter_map(|(dim_idx, &v)| { + if dims.contains(&dim_idx) { + None + } else { + Some(v) + } + }) + .collect::>(); + self.reshape(dims) + } + } + } + + fn reduce_impl(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result { + let dim = dim.to_index(self.shape(), op.name())?; + let storage = self.storage().reduce_op(op, self.layout(), &[dim])?; + let mut dims = self.dims().to_vec(); + dims[dim] = 1; + let op = match op { + ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => { + BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec())) + } + ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(), + }; + let res = from_storage(storage, dims, op, false); + if keepdim { + Ok(res) + } else { + res.squeeze_dims(&[dim]) + } + } + + fn sum_impl(&self, sum_dims: D, keepdim: bool) -> Result { + let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?; + let storage = self + .storage() + .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?; + let mut dims = self.dims().to_vec(); + for &sum_dim in sum_dims.iter() { + dims[sum_dim] = 1 + } + let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec())); + let sum = from_storage(storage, dims, op, false); + if keepdim { + Ok(sum) + } else { + sum.squeeze_dims(&sum_dims) + } + } + + /// Roll the tensor input along the given dimension. + /// Elements that are shifted beyond the last position are re-introduced at the first position. + /// + /// ```rust + /// # use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.roll(1, 0)?; + /// assert_eq!(tensor.to_vec2::()?, &[[4., 5.], [0., 1.], [2., 3.]]); + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.roll(-1, 0)?; + /// assert_eq!(tensor.to_vec2::()?, &[[2., 3.], [4., 5.], [0., 1.]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn roll(&self, shift: i32, dim: D) -> Result + where + D: Dim + Clone, + { + let dim = dim.to_index(self.shape(), "roll")?; + let dim_size = self.dim(dim)?; + let shift = shift.rem_euclid(dim_size as i32) as usize; + if shift == 0 { + Ok(self.clone()) + } else { + let a = self.narrow(dim, 0, dim_size - shift)?; + let b = self.narrow(dim, dim_size - shift, shift)?; + Tensor::cat(&[&b, &a], dim) + } + } + + /// Returns the sum of all elements in the input tensor. The sum is performed over all the + /// input dimensions. + /// + /// The resulting tensor has a shape that is similar to the shape of the input tensor, except + /// that the number of elements for each dimension index in `sum_dims` is 1. 
+ /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; + /// let s = a.sum_keepdim(0)?; + /// assert_eq!(s.to_vec2::()?, &[[2., 4.]]); + /// let s = a.sum_keepdim(1)?; + /// assert_eq!(s.to_vec2::()?, &[[1.], [5.]]); + /// let s = a.sum_keepdim((0, 1))?; + /// assert_eq!(s.to_vec2::()?, &[[6.]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn sum_keepdim(&self, sum_dims: D) -> Result { + self.sum_impl(sum_dims, true) + } + + /// Returns the sum of all elements in the input tensor. The sum is performed over all the + /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than + /// kept. + pub fn sum(&self, sum_dims: D) -> Result { + self.sum_impl(sum_dims, false) + } + + /// Returns the mean of all elements in the input tensor. The mean is performed over all the + /// input dimensions. + /// + /// The resulting tensor has a shape that is similar to the shape of the input tensor, except + /// that the number of elements for each dimension index in `mean_dims` is 1. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; + /// let s = a.mean_keepdim(0)?; + /// assert_eq!(s.to_vec2::()?, &[[1., 2.]]); + /// let s = a.mean_keepdim(1)?; + /// assert_eq!(s.to_vec2::()?, &[[0.5], [2.5]]); + /// let s = a.mean_keepdim((0, 1))?; + /// assert_eq!(s.to_vec2::()?, &[[1.5]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn mean_keepdim(&self, mean_dims: D) -> Result { + let mean_dims = mean_dims.to_indexes(self.shape(), "mean-keepdim")?; + let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product(); + let scale = 1f64 / (reduced_dim as f64); + self.sum_impl(mean_dims, true)? * scale + } + + /// Returns the mean of all elements in the input tensor. The mean is performed over all the + /// input dimensions and compared to `mean_keepdim` these dimensions are squeezed rather than + /// kept. + pub fn mean(&self, mean_dims: D) -> Result { + let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?; + let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product(); + let scale = 1f64 / (reduced_dim as f64); + self.sum_impl(mean_dims, false)? * scale + } + + /// Returns the unbiased variance over the selected dimension. + pub fn var_keepdim(&self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "var")?; + let mean = self.mean_keepdim(dim)?; + let squares = self.broadcast_sub(&mean)?.sqr()?; + squares.sum_impl(dim, true)? / (self.dim(dim)? - 1) as f64 + } + + /// Returns the unbiased variance over the selected dimension. + pub fn var(&self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "var")?; + self.var_keepdim(dim)?.squeeze(dim) + } + + /// Gathers the maximum value across the selected dimension. The resulting shape has the same + /// number of dimensions as the original tensor and the select dimension has a single element. + pub fn max_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::Max) + } + + /// Similar to `max_keepdim` but the target dimension is squeezed. + pub fn max(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::Max) + } + + /// Gathers the minimum value across the selected dimension. The resulting shape has the same + /// number of dimensions as the original tensor and the select dimension has a single element. 
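+    ///
+    /// A small sketch of the keepdim behaviour:
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let a = Tensor::new(&[[3u32, 1, 4], [1, 5, 9]], &Device::Cpu)?;
+    /// let m = a.min_keepdim(1)?;
+    /// assert_eq!(m.to_vec2::<u32>()?, &[[1], [1]]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```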
+ pub fn min_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::Min) + } + + /// Similar to `min_keepdim` but the target dimension is squeezed. + pub fn min(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::Min) + } + + pub fn argmax_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::ArgMax) + } + + /// Similar to `argmax_keepdim` but the target dimension is squeezed. + pub fn argmax(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::ArgMax) + } + + pub fn argmin_keepdim(&self, dim: D) -> Result { + self.reduce_impl(dim, true, ReduceOp::ArgMin) + } + + /// Similar to `argmin_keepdim` but the target dimension is squeezed. + pub fn argmin(&self, dim: D) -> Result { + self.reduce_impl(dim, false, ReduceOp::ArgMin) + } + + /// Element-wise comparison between two tensors, e.g. equality, greater than, ... The actual + /// comparison operation is specified by the `op` argument. + /// + /// The returned tensor has the same shape as the original tensors and uses `u8` elements. + pub fn cmp(&self, rhs: T, op: CmpOp) -> Result { + let rhs = match rhs.to_tensor_scalar()? { + crate::scalar::TensorScalar::Tensor(rhs) => rhs, + crate::scalar::TensorScalar::Scalar(rhs) => rhs + .to_dtype(self.dtype())? + .to_device(self.device())? + .broadcast_as(self.shape())?, + }; + let shape = self.same_shape_binary_op(&rhs, "cmp")?; + let storage = self + .storage() + .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?; + let op = BackpropOp::new1(self, |a| Op::Cmp(a, op)); + Ok(from_storage(storage, shape.dims(), op, false)) + } + + /// Element-wise equality. + pub fn eq(&self, rhs: T) -> Result { + self.cmp(rhs, CmpOp::Eq) + } + + /// Element-wise non-equality. + pub fn ne(&self, rhs: T) -> Result { + self.cmp(rhs, CmpOp::Ne) + } + + /// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self < + /// rhs` and 0 otherwise. + pub fn lt(&self, rhs: T) -> Result { + self.cmp(rhs, CmpOp::Lt) + } + + /// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self > + /// rhs` and 0 otherwise. + pub fn gt(&self, rhs: T) -> Result { + self.cmp(rhs, CmpOp::Gt) + } + + /// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >= + /// rhs` and 0 otherwise. + pub fn ge(&self, rhs: T) -> Result { + self.cmp(rhs, CmpOp::Ge) + } + + /// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <= + /// rhs` and 0 otherwise. + pub fn le(&self, rhs: T) -> Result { + self.cmp(rhs, CmpOp::Le) + } + + /// Clamp the tensor values to be between `min` and `max`. + pub fn clamp(&self, min: T1, max: T2) -> Result { + self.maximum(min)?.minimum(max) + } + + /// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element. + /// + /// The input tensor should have three dimensions, `(batch, channels, l)`, the returned + /// tensor also has three dimensions, `(batch, channels, target_size)`. + pub fn interpolate1d(&self, target_size: usize) -> Result { + let (n, c, _l) = self.dims3()?; + let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size }); + let storage = self + .storage() + .upsample_nearest1d(self.layout(), target_size)?; + Ok(from_storage(storage, (n, c, target_size), op, false)) + } + + /// Alias for `interpolate1d`. 
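+    ///
+    /// A minimal sketch (this assumes floor-based nearest indexing, i.e. output index `i`
+    /// reads input index `i * l / target_size`):
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let t = Tensor::new(&[[[1f32, 2.]]], &Device::Cpu)?;   // (batch, channels, l) = (1, 1, 2)
+    /// let up = t.upsample_nearest1d(4)?;                     // (1, 1, 4)
+    /// assert_eq!(up.to_vec3::<f32>()?, &[[[1., 1., 2., 2.]]]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```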
+ pub fn upsample_nearest1d(&self, target_size: usize) -> Result { + self.interpolate1d(target_size) + } + + /// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the + /// nearest element. + /// + /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned + /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. + pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result { + let (n, c, _h, _w) = self.dims4()?; + let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D { + arg, + target_h, + target_w, + }); + let storage = self + .storage() + .upsample_nearest2d(self.layout(), target_h, target_w)?; + Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) + } + + /// Alias for `interpolate2d`. + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result { + self.interpolate2d(target_h, target_w) + } + + /// 2D average pooling over an input tensor with multiple channels. + /// + /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned + /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on + /// the two last dimensions using a kernel of size `sz`. The returned element is the average + /// value over the kernel window. + pub fn avg_pool2d(&self, sz: T) -> Result { + let sz = sz.to_usize2(); + self.avg_pool2d_with_stride(sz, sz) + } + + /// Same as `avg_pool2d` but with a `stride` that can be set to a value different from the + /// kernel size. + pub fn avg_pool2d_with_stride( + &self, + kernel_size: T, + stride: T, + ) -> Result { + let kernel_size = kernel_size.to_usize2(); + let stride = stride.to_usize2(); + let (n, c, h, w) = self.dims4()?; + if h < kernel_size.0 || w < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + } + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::AvgPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .avg_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + + /// 2D max pooling over an input tensor with multiple channels. + /// + /// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned + /// tensor also has four dimensions, `(batch, channels, h', w')`. The pooling is performed on + /// the two last dimensions using a kernel of size `sz`, the returned element is the maximum + /// value over the kernel window. + pub fn max_pool2d(&self, sz: T) -> Result { + let sz = sz.to_usize2(); + self.max_pool2d_with_stride(sz, sz) + } + + /// Same as `max_pool2d` but with a `stride` that can be set to a value different from the + /// kernel size. 
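+    ///
+    /// A small sketch, using the output-size rule `h_out = (h - kernel) / stride + 1`:
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
+    /// let p = t.max_pool2d_with_stride(2, 2)?;
+    /// assert_eq!(p.dims(), &[1, 1, 2, 2]);
+    /// assert_eq!(p.flatten_all()?.to_vec1::<f32>()?, &[5., 7., 13., 15.]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```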
+ pub fn max_pool2d_with_stride( + &self, + kernel_size: T, + stride: T, + ) -> Result { + let kernel_size = kernel_size.to_usize2(); + let stride = stride.to_usize2(); + let (n, c, h, w) = self.dims4()?; + if h < kernel_size.0 || w < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + } + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .max_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// + /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`. + pub fn matmul(&self, rhs: &Self) -> Result { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if c_shape.elem_count() == 0 || k == 0 { + return Tensor::zeros(c_shape, self.dtype(), self.device()); + } + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let storage = self.storage().matmul_with_alpha( + &rhs.storage(), + None, + (batching, m, n, k), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, rhs, Op::Matmul); + Ok(from_storage(storage, c_shape, op, false)) + } + + /// Matrix-multiplication with broadcasting support. + /// + /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as + /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has + /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`. + pub fn broadcast_matmul(&self, rhs: &Self) -> Result { + let lhs = self; + let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?; + let l_broadcast = l_shape != *lhs.shape(); + let r_broadcast = r_shape != *rhs.shape(); + // TODO: Avoid concretising the broadcasted matrixes via contiguous. + match (l_broadcast, r_broadcast) { + (true, true) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?), + (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?), + (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs), + (false, false) => lhs.matmul(rhs), + } + } + + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled + /// and then added to the output tensor, the bias tensor `c`. 
+ /// + /// If `scale` is None, then the output is as follows: + /// `c := c + axb` + /// + /// Else: + /// `c := c + scale * (axb)` + /// + /// This function is faster than a matmul followed by some scaling multiply because the scaling is fused in the GEMM kernel. + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. However, this also means + /// there is an allocation saved as the output is in `c`. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// * `c` - A tensor with dimensions `b1, b2, ..., bi, m, n`, into which the result is accumulated and added to. + /// * `scale` - Factor to multiply `self` x `rhs` by + pub fn matmul_with_alpha_beta( + &self, + rhs: &Self, + c: &mut Self, + scale: Option, + ) -> Result<()> { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let exp_c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if exp_c_shape.elem_count() == 0 || k == 0 { + bail!("Expected `c` to have more than one element, got 0."); + } + if exp_c_shape != c.shape().clone() { + Err(Error::UnexpectedShape { + msg: "`c` has an unexpected shape.".to_string(), + expected: exp_c_shape, + got: c.shape().clone(), + })? + } + + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul_with_alpha_beta", + } + .bt())? + } + + self.storage().matmul_with_alpha_beta( + &rhs.storage(), + &mut c.storage_mut(), + scale, + (batching, m, n, k), + self.layout(), + rhs.layout(), + c.layout(), + ) + } + + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. The result is scaled. + /// + /// This function is faster than a matmul followed by some scaling multiply because the scaling is fused in the GEMM kernel. + /// + /// The output is as follows: + /// `scale * (axb)` + /// + /// + /// This is incompatible with gradient tracking. No gradients will be tracked on this operation. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// * `scale` - Factor to multiply `self` x `rhs` by. + pub fn matmul_with_alpha(&self, rhs: &Self, scale: Option) -> Result { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + if dim < 2 || b_dims.len() != dim { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + } + .bt())? 
+ } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + + let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); + if c_shape.elem_count() == 0 || k == 0 { + return Tensor::zeros(c_shape, self.dtype(), self.device()); + } + let batching: usize = a_dims[..dim - 2].iter().product(); + let batching_b: usize = b_dims[..dim - 2].iter().product(); + if k != k2 || batching != batching_b { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul_with_alpha", + } + .bt())? + } + + let storage = self.storage().matmul_with_alpha( + &rhs.storage(), + scale, + (batching, m, n, k), + self.layout(), + rhs.layout(), + )?; + let op = BackpropOp::new2(self, rhs, Op::Matmul); + Ok(from_storage(storage, c_shape, op, false)) + } + + /// Matrix-multiplication with broadcasting support and fused scaling. + /// + /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as + /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has + /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`. + pub fn broadcast_matmul_with_alpha(&self, rhs: &Self, scale: Option) -> Result { + let lhs = self; + let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?; + let l_broadcast = l_shape != *lhs.shape(); + let r_broadcast = r_shape != *rhs.shape(); + // TODO: Avoid concretising the broadcasted matrixes via contiguous. + match (l_broadcast, r_broadcast) { + (true, true) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul_with_alpha(&rhs.broadcast_as(&r_shape)?.contiguous()?, scale), + (false, true) => { + lhs.matmul_with_alpha(&rhs.broadcast_as(&r_shape)?.contiguous()?, scale) + } + (true, false) => lhs + .broadcast_as(&l_shape)? + .contiguous()? + .matmul_with_alpha(rhs, scale), + (false, false) => lhs.matmul_with_alpha(rhs, scale), + } + } + + /// Returns a tensor with the same shape as the input tensor, the values are taken from + /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the + /// input tensor is equal to zero. + pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result { + let _shap = self.same_shape_binary_op(on_true, "where_cond")?; + let shape = self.same_shape_binary_op(on_false, "where_cond")?; + let storage = self.storage().where_cond( + self.layout(), + &on_true.storage(), + on_true.layout(), + &on_false.storage(), + on_false.layout(), + )?; + let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond); + Ok(from_storage(storage, shape, op, false)) + } + + /// Returns a tensor with the values from the `self` tensor at the index corresponding to the + /// values hold in the `ids` tensor. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `v, h`. + /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive). + /// + /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the + /// vocabulary size, and `h` the hidden size. 
+ /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let values = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?; + /// let emb = values.embedding(&ids)?; + /// assert_eq!(emb.to_vec2::()?, &[[4., 5.], [2., 3.], [4., 5.]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn embedding(&self, ids: &Self) -> Result { + if self.rank() != 2 || ids.rank() != 1 { + Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: ids.shape().clone(), + op: "embedding", + } + .bt())? + } + self.index_select(ids, 0) + } + + /// Returns an iterator over position of the elements in the storage when ranging over the + /// index tuples in lexicographic order. + pub fn strided_index(&self) -> crate::StridedIndex { + self.layout.strided_index() + } + + /// Similar to `strided_index` but returns the position of the start of each contiguous block + /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator + /// will only return the start offset and the size would be the number of elements in the + /// tensor. + pub fn strided_blocks(&self) -> crate::StridedBlocks { + self.layout.strided_blocks() + } + + /// Returns the data contained in a 1D tensor as a vector of scalar values. + pub fn to_vec1(&self) -> Result> { + if self.rank() != 1 { + Err(Error::UnexpectedNumberOfDims { + expected: 1, + got: self.rank(), + shape: self.shape().clone(), + } + .bt())? + } + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let data = match self.layout.contiguous_offsets() { + Some((o1, o2)) => data[o1..o2].to_vec(), + None => self.strided_index().map(|i| data[i]).collect(), + }; + Ok::, Error>(data) + }; + match &*self.storage() { + Storage::Cpu(storage) => from_cpu_storage(storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + } + } + + /// Returns the data contained in a 2D tensor as a vector of vector of scalar values. + pub fn to_vec2(&self) -> Result>> { + let (dim1, dim2) = self.dims2()?; + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let mut rows = vec![]; + match self.layout.contiguous_offsets() { + Some((o1, o2)) => { + let data = &data[o1..o2]; + for idx_row in 0..dim1 { + rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec()) + } + } + None => { + let mut src_index = self.strided_index(); + for _idx_row in 0..dim1 { + let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect(); + rows.push(row) + } + assert!(src_index.next().is_none()); + } + } + Ok(rows) + }; + match &*self.storage() { + Storage::Cpu(storage) => from_cpu_storage(storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + } + } + + /// Returns the data contained in a 3D tensor. 
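+    ///
+    /// A small usage sketch: a contiguous `(2, 2, 2)` tensor is returned as nested vectors.
+    ///
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let t = Tensor::arange(0u32, 8u32, &Device::Cpu)?.reshape((2, 2, 2))?;
+    /// assert_eq!(t.to_vec3::<u32>()?, &[[[0, 1], [2, 3]], [[4, 5], [6, 7]]]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```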
+ pub fn to_vec3(&self) -> Result>>> { + let (dim1, dim2, dim3) = self.dims3()?; + let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let mut top_rows = vec![]; + match self.layout.contiguous_offsets() { + Some((o1, o2)) => { + let data = &data[o1..o2]; + let dim23 = dim2 * dim3; + for idx1 in 0..dim1 { + let data = &data[idx1 * dim23..(idx1 + 1) * dim23]; + let mut rows = vec![]; + for idx2 in 0..dim2 { + rows.push(data[idx2 * dim3..(idx2 + 1) * dim3].to_vec()) + } + top_rows.push(rows); + } + } + None => { + let mut src_index = self.strided_index(); + for _idx in 0..dim1 { + let mut rows = vec![]; + for _jdx in 0..dim2 { + let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect(); + rows.push(row) + } + top_rows.push(rows); + } + assert!(src_index.next().is_none()); + } + } + Ok(top_rows) + }; + match &*self.storage() { + Storage::Cpu(storage) => from_cpu_storage(storage), + Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?), + } + } + + /// The dtype for the elements stored in the input tensor. + pub fn dtype(&self) -> DType { + self.dtype + } + + /// The device on which the input tensor is located. + pub fn device(&self) -> &Device { + &self.device + } + + /// The tensor shape, i.e. dimension sizes on each axis. + pub fn shape(&self) -> &Shape { + self.layout().shape() + } + + /// The dimension size for this tensor on each axis. + pub fn dims(&self) -> &[usize] { + self.shape().dims() + } + + /// The dimension size for a specified dimension index. + pub fn dim(&self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "dim")?; + Ok(self.dims()[dim]) + } + + /// The layout of the input tensor, this stores both the shape of the tensor as well as the + /// strides and the start offset to apply to the underlying storage. + pub fn layout(&self) -> &Layout { + &self.layout + } + + pub fn stride(&self) -> &[usize] { + self.layout.stride() + } + + /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc. + pub fn rank(&self) -> usize { + self.shape().rank() + } + + /// The number of elements stored in this tensor. + pub fn elem_count(&self) -> usize { + self.shape().elem_count() + } + + /// The unique identifier for this tensor. + pub fn id(&self) -> TensorId { + self.id + } + + /// Whether this tensor is a variable or not. A variable is a tensor for which gradient is + /// tracked and on which backpropagation can be performed. + pub fn is_variable(&self) -> bool { + self.is_variable + } + + pub(crate) fn op(&self) -> &Option { + &self.op + } + + /// Computes the sum of all the elements in this tensor and returns a tensor holding this + /// scalar with zero dimensions. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.sum_all()?; + /// assert_eq!(tensor.to_scalar::()?, 15.); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn sum_all(&self) -> Result { + let dims: Vec<_> = (0..self.rank()).collect(); + self.sum(dims) + } + + pub fn mean_all(&self) -> Result { + self.sum_all()? 
/ self.elem_count() as f64 + } + + fn flatten_( + &self, + start_dim: Option, + end_dim: Option, + ) -> Result { + if self.rank() == 0 { + self.reshape(1) + } else { + let start_dim = match start_dim { + None => 0, + Some(dim) => dim.to_index(self.shape(), "flatten")?, + }; + let end_dim = match end_dim { + None => self.rank() - 1, + Some(dim) => dim.to_index(self.shape(), "flatten")?, + }; + if start_dim < end_dim { + let dims = self.dims(); + let mut dst_dims = dims[..start_dim].to_vec(); + dst_dims.push(dims[start_dim..end_dim + 1].iter().product::()); + if end_dim + 1 < dims.len() { + dst_dims.extend(&dims[end_dim + 1..]); + } + self.reshape(dst_dims) + } else { + Ok(self.clone()) + } + } + } + + /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both + /// inclusive). + pub fn flatten(&self, start_dim: D1, end_dim: D2) -> Result { + self.flatten_(Some(start_dim), Some(end_dim)) + } + + /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive). + pub fn flatten_to(&self, end_dim: D) -> Result { + self.flatten_(None::, Some(end_dim)) + } + + /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last + /// dimension. + pub fn flatten_from(&self, start_dim: D) -> Result { + self.flatten_(Some(start_dim), None::) + } + + /// Flattens the input tensor by reshaping it into a one dimension tensor. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.flatten_all()?; + /// assert_eq!(tensor.to_vec1::()?, &[0., 1., 2., 3., 4., 5.]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn flatten_all(&self) -> Result { + self.flatten_(None::, None::) + } + + /// Returns the sub-tensor fixing the index at `i` on the first dimension. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let t = tensor.get(0)?; + /// assert_eq!(t.to_vec1::()?, &[0., 1.]); + /// let t = tensor.get(1)?; + /// assert_eq!(t.to_vec1::()?, &[2., 3.]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn get(&self, i: usize) -> Result { + let dims = self.dims(); + if dims.is_empty() { + Ok(self.clone()) + } else { + self.narrow(0, i, 1)?.reshape(&dims[1..]) + } + } + + /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let t = tensor.get_on_dim(1, 0)?; + /// assert_eq!(t.to_vec1::()?, &[0., 2., 4.]); + /// let t = tensor.get_on_dim(1, 1)?; + /// assert_eq!(t.to_vec1::()?, &[1., 3., 5.]); + /// let t = tensor.get_on_dim(0, 1)?; + /// assert_eq!(t.to_vec1::()?, &[2., 3.]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn get_on_dim(&self, dim: D, index: usize) -> Result { + let dim = dim.to_index(self.shape(), "get_on_dim")?; + self.narrow(dim, index, 1)?.squeeze(dim) + } + + /// Returns a tensor that is a transposed version of the input, the two last dimensions of the + /// input are swapped. 
+ /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.t()?; + /// assert_eq!(tensor.to_vec2::()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn t(&self) -> Result { + let rank = self.rank(); + if rank < 2 { + Err(Error::UnexpectedNumberOfDims { + expected: 2, + got: rank, + shape: self.shape().clone(), + } + .bt())? + } + self.transpose(rank - 2, rank - 1) + } + + /// Returns a tensor that is a transposed version of the input, the given dimensions are + /// swapped. + pub fn transpose(&self, dim1: D1, dim2: D2) -> Result { + let dim1 = dim1.to_index(self.shape(), "transpose")?; + let dim2 = dim2.to_index(self.shape(), "transpose")?; + if dim1 == dim2 { + return Ok(self.clone()); + } + let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2)); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout.transpose(dim1, dim2)?, + op, + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + + /// Returns a tensor with the same data as the input where the dimensions have been permuted. + /// dims must be a permutation, i.e. include each dimension index exactly once. + /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?; + /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]); + /// let tensor = tensor.permute((2, 3, 1, 0))?; + /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn permute(&self, dims: D) -> Result { + let dims = dims.to_indexes(self.shape(), "permute")?; + // O(n^2) permutation check but these arrays are small. + let is_permutation = + dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i)); + if !is_permutation { + bail!( + "dimension mismatch in permute, tensor {:?}, dims: {:?}", + self.dims(), + dims + ) + } + let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone())); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout.permute(&dims)?, + op, + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + + /// Returns true if the data is stored in a C contiguous (aka row major) way. + pub fn is_contiguous(&self) -> bool { + self.layout.is_contiguous() + } + + /// Returns true if the data is stored in a Fortran contiguous (aka column major) way. + pub fn is_fortran_contiguous(&self) -> bool { + self.layout.is_fortran_contiguous() + } + + /// Compared to clone, this copies the actual storage but may fail because of running out of + /// memory. + pub fn copy(&self) -> Result { + let op = BackpropOp::new1(self, Op::Copy); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)), + layout: self.layout.clone(), + op, + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + + /// Returns a new tensor detached from the current graph, gradient are not propagated through + /// this new node. The storage of this tensor is shared with the initial tensor. + /// + /// If the tensor is already detached from the computation graph, the same tensor is returned. 
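+    ///
+    /// A minimal sketch: the detached tensor keeps the same values, only the backprop graph
+    /// link is dropped.
+    ///
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
+    /// let u = (&t + 1.0)?.detach();
+    /// assert_eq!(u.to_vec1::<f32>()?, &[2., 3., 4.]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```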
+ pub fn detach(&self) -> Tensor { + if self.op.is_none() && !self.is_variable { + self.clone() + } else { + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout.clone(), + op: BackpropOp::none(), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Tensor(Arc::new(tensor_)) + } + } + + /// If the target device is the same as the tensor device, only a shallow copy is performed. + pub fn to_device(&self, device: &Device) -> Result { + if self.device().same_device(device) { + Ok(self.clone()) + } else { + let storage = match (&*self.storage(), device) { + (Storage::Cpu(storage), Device::Cuda(cuda)) => { + Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) + } + (Storage::Cpu(storage), Device::Metal(metal)) => { + Storage::Metal(metal.storage_from_cpu_storage(storage)?) + } + (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), + (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), + (Storage::Cuda(storage), Device::Cuda(cuda)) => { + // TODO: Avoid passing through the cpu storage here, especially if the gpu ids + // are the same. + let cpu_storage = storage.to_cpu_storage()?; + Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?) + } + (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()), + _ => { + bail!( + "not implemented yet, self.device: {:?}, device: {:?}", + self.device(), + device + ) + } + }; + let op = BackpropOp::new1(self, Op::ToDevice); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: Arc::new(RwLock::new(storage)), + layout: self.layout.clone(), + op, + is_variable: false, + dtype: self.dtype, + device: device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + } + + /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted + /// on the left. + pub fn broadcast_left>(&self, left_shape: S) -> Result { + let left_shape = left_shape.into(); + let mut dims = left_shape.into_dims(); + dims.extend(self.dims()); + self.broadcast_as(dims) + } + + /// Broadcast the input tensor to the target shape. This returns an error if the input shape is + /// not compatible with the target shape. + /// + /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or + /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have + /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If + /// `i_a` is equal to 1, any value can be used. + pub fn broadcast_as>(&self, shape: S) -> Result { + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout.broadcast_as(shape)?, + op: BackpropOp::new1(self, Op::Broadcast), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + + /// An alias for broadcast_as. + pub fn expand>(&self, shape: S) -> Result { + self.broadcast_as(shape) + } + + /// Casts the input tensor to the target `dtype`. 
+ /// + /// ```rust + /// use mcandle_core::{Tensor, Device}; + /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?; + /// assert_eq!(tensor.to_scalar::()?, 3.14159265358979); + /// let tensor = tensor.to_dtype(mcandle_core::DType::F32)?; + /// assert_eq!(tensor.to_scalar::()?, 3.1415927); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn to_dtype(&self, dtype: DType) -> Result { + if self.dtype() == dtype { + Ok(self.clone()) + } else { + let shape = self.shape(); + let storage = self.storage().to_dtype(self.layout(), dtype)?; + let op = BackpropOp::new1(self, Op::ToDType); + Ok(from_storage(storage, shape.clone(), op, false)) + } + } + + /// Returns a tensor that is in row major order. This is the same as the original tensor if it + /// was already contiguous, otherwise a copy is triggered. + pub fn contiguous(&self) -> Result { + if self.is_contiguous() { + Ok(self.clone()) + } else { + let shape = self.shape(); + let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let op = BackpropOp::new1(self, Op::Copy); + Ok(from_storage(storage, shape.clone(), op, false)) + } + } + + /// Returns a tensor that is in row major order. This always makes a copy. + pub fn force_contiguous(&self) -> Result { + let shape = self.shape(); + let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let op = BackpropOp::new1(self, Op::Copy); + Ok(from_storage(storage, shape.clone(), op, false)) + } + + /// Create a variable based on the values currently stored in a tensor. The storage is always + /// copied. + pub(crate) fn make_var(&self) -> Result { + let shape = self.shape().clone(); + let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + Ok(from_storage(storage, shape, BackpropOp::none(), true)) + } + + /// Reshape returns a tensor with the target shape provided that the number of elements of the + /// original tensor is the same. + /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses + /// a new storage and copies the data over, the returned tensor is always contiguous. + /// + /// The shape can be specified using a tuple of `usize` and at most one `()` in which case + /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so + /// as to match the number of elements in the tensor. 
+ /// + /// ```rust + /// # use mcandle_core::{Tensor, DType, Device, D}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = a.reshape((1, 6))?; + /// assert_eq!(c.shape().dims(), &[1, 6]); + /// + /// let c = a.reshape((3, 2))?; + /// assert_eq!(c.shape().dims(), &[3, 2]); + /// + /// let c = a.reshape((2, (), 1))?; + /// assert_eq!(c.shape().dims(), &[2, 3, 1]); + /// + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn reshape(&self, s: S) -> Result { + let shape = s.into_shape(self.elem_count())?; + if shape.elem_count() != self.elem_count() { + return Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: shape, + op: "reshape", + } + .bt()); + } + let op = BackpropOp::new1(self, Op::Reshape); + if self.is_contiguous() { + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()), + op, + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } else { + let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + Ok(from_storage(storage, shape, op, false)) + } + } + + /// Creates a new tensor with the specified dimension removed if its size was one. + /// + /// ```rust + /// # use mcandle_core::{Tensor, DType, Device, D}; + /// let a = Tensor::zeros((2, 3, 1), DType::F32, &Device::Cpu)?; + /// + /// let c = a.squeeze(2)?; + /// assert_eq!(c.shape().dims(), &[2, 3]); + /// + /// let c = a.squeeze(D::Minus1)?; + /// assert_eq!(c.shape().dims(), &[2, 3]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn squeeze(&self, dim: D) -> Result { + // The PyTorch semantics are to return the same tensor if the target dimension + // does not have a size of 1. + let dims = self.dims(); + let dim = dim.to_index(self.shape(), "squeeze")?; + if dims[dim] == 1 { + let mut dims = dims.to_vec(); + let mut strides = self.stride().to_vec(); + dims.remove(dim); + strides.remove(dim); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(dims.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } else { + Ok(self.clone()) + } + } + + /// Creates a new tensor with a dimension of size one inserted at the specified position. + /// + /// ```rust + /// # use mcandle_core::{Tensor, DType, Device, D}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = a.unsqueeze(0)?; + /// assert_eq!(c.shape().dims(), &[1, 2, 3]); + /// + /// let c = a.unsqueeze(D::Minus1)?; + /// assert_eq!(c.shape().dims(), &[2, 3, 1]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn unsqueeze(&self, dim: D) -> Result { + let mut dims = self.dims().to_vec(); + let mut strides = self.stride().to_vec(); + let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?; + // Cannot panic because to_index_plus_one already checks dimensions + dims.insert(dim, 1); + // Any stride would work here, but we pick one so as to maximize the probability to remain + // C contiguous. 
+ let stride = if dim < strides.len() { strides[dim] } else { 1 }; + strides.insert(dim, stride); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: Layout::new(dims.into(), strides, self.layout.start_offset()), + op: BackpropOp::new1(self, Op::Reshape), + is_variable: false, + dtype: self.dtype, + device: self.device.clone(), + }; + Ok(Tensor(Arc::new(tensor_))) + } + + /// Stacks two or more tensors along a particular dimension. + /// + /// All tensors must have the same rank, and the output has one additional rank + /// + /// ```rust + /// # use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::stack(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[2, 2, 3]); + /// + /// let c = Tensor::stack(&[&a, &b], 2)?; + /// assert_eq!(c.shape().dims(), &[2, 3, 2]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn stack, D: Dim>(args: &[A], dim: D) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "stack" }.bt())? + } + let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?; + let args = args + .iter() + .map(|t| t.as_ref().unsqueeze(dim)) + .collect::>>()?; + Self::cat(&args, dim) + } + + /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the + /// input tensor values and `right` elements after. + pub fn pad_with_zeros(&self, dim: D, left: usize, right: usize) -> Result { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[self, &right], dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self], dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self, &right], dim) + } + } + + /// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the + /// input tensor values and `right` elements after. + pub fn pad_with_same(&self, dim: D, left: usize, right: usize) -> Result { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if self.elem_count() == 0 { + bail!("cannot use pad_with_same on an empty tensor") + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_same")?; + let r = self.narrow(dim, self.dim(dim)? - 1, 1)?; + let mut v = vec![self]; + for _ in 0..right { + v.push(&r) + } + Tensor::cat(&v, dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_same")?; + let l = self.narrow(dim, 0, 1)?; + let mut v = vec![]; + for _ in 0..left { + v.push(&l) + } + v.push(self); + Tensor::cat(&v, dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_same")?; + let l = self.narrow(dim, 0, 1)?; + let r = self.narrow(dim, self.dim(dim)? 
- 1, 1)?; + let mut v = vec![]; + for _ in 0..left { + v.push(&l) + } + v.push(self); + for _ in 0..right { + v.push(&r) + } + Tensor::cat(&v, dim) + } + } + + /// Run the `forward` method of `m` on `self`. + pub fn apply(&self, m: &M) -> Result { + m.forward(self) + } + + /// Run the `forward` method of `m` on `self`. + pub fn apply_t(&self, m: &M, train: bool) -> Result { + m.forward_t(self, train) + } + + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { + self.storage.read().unwrap() + } + + pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> { + self.storage.write().unwrap() + } + + // If we extend the visibility of this function to be usable outside of this crate, we should + // make it unsafe. + pub(crate) fn storage_mut_and_layout( + &self, + ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) { + let storage = self.storage.write().unwrap(); + (storage, &self.layout) + } + + /// The storage used by this tensor, together with the layout to use to access it safely. + pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) { + let storage = self.storage.read().unwrap(); + (storage, &self.layout) + } + + pub(crate) fn same_storage(&self, rhs: &Self) -> bool { + let lhs: &RwLock = self.storage.as_ref(); + let rhs: &RwLock = rhs.storage.as_ref(); + std::ptr::eq(lhs, rhs) + } + + /// Normalize a 'relative' axis value: positive values are kept, negative + /// values means counting the dimensions from the back. + pub fn normalize_axis(&self, axis: i64) -> Result { + let rank = self.rank() as i64; + if rank <= axis { + bail!("axis {axis} is too large, tensor rank {rank}") + } else if 0 <= axis { + Ok(axis as usize) + } else { + let naxis = rank + axis; + if naxis < 0 { + bail!("axis {axis} is too small, tensor rank {rank}") + } + Ok(naxis as usize) + } + } + + /// Returns a lower triangular matrix of ones of size n by n. + pub fn tril2(n: usize, dtype: DType, device: &Device) -> Result { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.le(&t2)?.to_dtype(dtype) + } + + /// Returns an upper triangular matrix of ones of size n by n. + pub fn triu2(n: usize, dtype: DType, device: &Device) -> Result { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.ge(&t2)?.to_dtype(dtype) + } + + /// Returns a matrix with a diagonal of ones of size n by n. + pub fn eye(n: usize, dtype: DType, device: &Device) -> Result { + let t = Tensor::arange(0u32, n as u32, device)?; + let t1 = t.reshape((1, n))?.broadcast_as((n, n))?; + let t2 = t.reshape((n, 1))?.broadcast_as((n, n))?; + t1.eq(&t2)?.to_dtype(dtype) + } + + /// Returns the cumulative sum of elements of the input tensor summed over the specified + /// dimension. + /// + /// This operation is most efficient when dim is the last dimension of the tensor. 
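+    ///
+    /// For example, on a 1D tensor:
+    ///
+    /// ```rust
+    /// use mcandle_core::{Tensor, Device};
+    /// let t = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
+    /// assert_eq!(t.cumsum(0)?.to_vec1::<f32>()?, &[1., 3., 6., 10.]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```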
+    pub fn cumsum<D: Dim>(&self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "cumsum")?;
+        let rank = self.rank();
+        if rank == 0 {
+            return Ok(self.clone());
+        }
+        let n_axis = self.dim(dim)?;
+        let triu = Tensor::triu2(n_axis, self.dtype(), self.device())?;
+        if rank == 1 {
+            self.unsqueeze(0)?.matmul(&triu)?.squeeze(0)
+        } else {
+            let last = rank - 1;
+            let t = self.transpose(dim, last)?;
+            let t = t.broadcast_matmul(&triu)?;
+            t.transpose(dim, last)
+        }
+    }
+
+    /// Returns log(sum(exp(tensor), dim)).
+    pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
+        if sum_dims.is_empty() {
+            return Ok(self.clone());
+        }
+        let max = sum_dims[1..]
+            .iter()
+            .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
+                max.max_keepdim(dim)
+            })?;
+        let exp = self.broadcast_sub(&max)?.exp()?;
+        let sum = exp.sum(sum_dims.clone())?;
+
+        sum.log()? + max.squeeze_dims(&sum_dims)
+    }
+
+    /// Pointwise pow operation.
+    pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
+        rhs.mul(&self.log()?)?.exp()
+    }
+
+    /// Broadcasting version of `pow`.
+    pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
+        rhs.broadcast_mul(&self.log()?)?.exp()
+    }
+
+    /// Returns a view of `self` which contains all slices of size `size` in the dimension `dim`,
+    /// stepped by `step`.
+    pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
+        // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804
+        let mut sizes = self.dims().to_vec();
+        let mut strides = self.stride().to_vec();
+
+        let dim = dim.to_index(self.shape(), "unfold")?;
+
+        let max_len = if self.dims().is_empty() {
+            1
+        } else {
+            sizes[dim]
+        };
+        if size > max_len {
+            bail!(
+                "unfold: maximum size for tensor at dimension {dim} is {max_len} but size is {size}"
+            )
+        }
+        sizes.push(size);
+        strides.push(if self.dims().is_empty() {
+            1
+        } else {
+            strides[dim]
+        });
+
+        if !self.dims().is_empty() {
+            sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize;
+            strides[dim] *= step;
+        }
+
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
+            op: BackpropOp::new1(self, Op::Reshape),
+            is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+}
+
+macro_rules!
bin_trait { + ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => { + impl> std::ops::$trait for Tensor { + type Output = Result; + + fn $fn1(self, rhs: B) -> Self::Output { + Tensor::$fn1(&self, rhs.borrow()) + } + } + + impl> std::ops::$trait for &Tensor { + type Output = Result; + + fn $fn1(self, rhs: B) -> Self::Output { + Tensor::$fn1(&self, rhs.borrow()) + } + } + + impl> std::ops::$trait for Result { + type Output = Result; + + fn $fn1(self, rhs: Tensor) -> Self::Output { + Tensor::$fn1(self?.borrow(), &rhs) + } + } + + impl> std::ops::$trait<&Tensor> for Result { + type Output = Result; + + fn $fn1(self, rhs: &Tensor) -> Self::Output { + Tensor::$fn1(self?.borrow(), rhs) + } + } + + impl> std::ops::$trait> for Tensor { + type Output = Result; + + fn $fn1(self, rhs: Result) -> Self::Output { + Tensor::$fn1(&self, rhs?.borrow()) + } + } + + impl> std::ops::$trait> for &Tensor { + type Output = Result; + + fn $fn1(self, rhs: Result) -> Self::Output { + Tensor::$fn1(&self, rhs?.borrow()) + } + } + + impl std::ops::$trait for Tensor { + type Output = Result; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } + + impl std::ops::$trait for &Tensor { + type Output = Result; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } + }; +} + +bin_trait!(Add, add, |_| 1., |v| v); +bin_trait!(Sub, sub, |_| 1., |v: f64| -v); +bin_trait!(Mul, mul, |v| v, |_| 0.); +bin_trait!(Div, div, |v| 1. / v, |_| 0.); + +impl std::ops::Add for f64 { + type Output = Result; + + fn add(self, rhs: Tensor) -> Self::Output { + rhs + self + } +} + +impl std::ops::Add<&Tensor> for f64 { + type Output = Result; + + fn add(self, rhs: &Tensor) -> Self::Output { + rhs + self + } +} + +impl std::ops::Mul for f64 { + type Output = Result; + + fn mul(self, rhs: Tensor) -> Self::Output { + rhs * self + } +} + +impl std::ops::Mul<&Tensor> for f64 { + type Output = Result; + + fn mul(self, rhs: &Tensor) -> Self::Output { + rhs * self + } +} + +impl std::ops::Sub for f64 { + type Output = Result; + + fn sub(self, rhs: Tensor) -> Self::Output { + rhs.affine(-1., self) + } +} + +impl std::ops::Sub<&Tensor> for f64 { + type Output = Result; + + fn sub(self, rhs: &Tensor) -> Self::Output { + rhs.affine(-1., self) + } +} + +impl std::ops::Div for f64 { + type Output = Result; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Tensor) -> Self::Output { + rhs.recip()? * self + } +} + +impl std::ops::Div<&Tensor> for f64 { + type Output = Result; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: &Tensor) -> Self::Output { + rhs.recip()? * self + } +} diff --git a/mcandle/mcandle-core/src/tensor_cat.rs b/mcandle/mcandle-core/src/tensor_cat.rs new file mode 100644 index 000000000..e7f6141c3 --- /dev/null +++ b/mcandle/mcandle-core/src/tensor_cat.rs @@ -0,0 +1,300 @@ +use crate::{shape::Dim, Error, Result, Shape, Tensor}; + +impl Tensor { + /// Concatenates two or more tensors along a particular dimension. 
+ /// + /// All tensors must of the same rank, and the output will have + /// the same rank + /// + /// ```rust + /// # use mcandle_core::{Tensor, DType, Device}; + /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; + /// + /// let c = Tensor::cat(&[&a, &b], 0)?; + /// assert_eq!(c.shape().dims(), &[4, 3]); + /// + /// let c = Tensor::cat(&[&a, &b], 1)?; + /// assert_eq!(c.shape().dims(), &[2, 6]); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn cat, D: Dim>(args: &[A], dim: D) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let dim = dim.to_index(arg0.shape(), "cat")?; + for arg in args { + arg.as_ref().check_dim(dim, "cat")?; + } + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg0.rank() != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: arg0.rank(), + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + } + let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous()); + if all_contiguous { + Self::cat_contiguous(args, dim) + } else if dim == 0 { + Self::cat0(args) + } else { + let args: Vec = args + .iter() + .map(|a| a.as_ref().transpose(0, dim)) + .collect::>>()?; + let cat = Self::cat0(&args)?; + cat.transpose(0, dim) + } + } + + fn cat0>(args: &[A]) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let rank = arg0.rank(); + let device = arg0.device(); + let dtype = arg0.dtype(); + let first_dims = arg0.shape().dims(); + let mut cat_dims = first_dims.to_vec(); + cat_dims[0] = 0; + let mut offsets = vec![0usize]; + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg.dtype() != dtype { + Err(Error::DTypeMismatchBinaryOp { + lhs: dtype, + rhs: arg.dtype(), + op: "cat", + } + .bt())? + } + if arg.device().location() != device.location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: device.location(), + rhs: arg.device().location(), + op: "cat", + } + .bt())? + } + if rank != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: rank, + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx == 0 { + cat_dims[0] += v2; + } + if dim_idx != 0 && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + let next_offset = offsets.last().unwrap() + arg.elem_count(); + offsets.push(next_offset); + } + let shape = Shape::from(cat_dims); + let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0)); + let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? 
}; + for (arg, &offset) in args.iter().zip(offsets.iter()) { + let arg = arg.as_ref(); + arg.storage() + .copy_strided_src(&mut storage, offset, arg.layout())?; + } + Ok(crate::tensor::from_storage(storage, shape, op, false)) + } + + fn cat_contiguous>(args: &[A], dim: usize) -> Result { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let rank = arg0.rank(); + let device = arg0.device(); + let dtype = arg0.dtype(); + let first_dims = arg0.shape().dims(); + let mut cat_dims = first_dims.to_vec(); + cat_dims[dim] = 0; + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg.dtype() != dtype { + Err(Error::DTypeMismatchBinaryOp { + lhs: dtype, + rhs: arg.dtype(), + op: "cat", + } + .bt())? + } + if arg.device().location() != device.location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: device.location(), + rhs: arg.device().location(), + op: "cat", + } + .bt())? + } + if rank != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: rank, + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx == dim { + cat_dims[dim] += v2; + } + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + } + let cat_target_dim_len = cat_dims[dim]; + let block_size: usize = cat_dims.iter().skip(1 + dim).product(); + let shape = Shape::from(cat_dims); + let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim)); + let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? }; + let mut dst_o = 0; + for arg in args.iter() { + let arg = arg.as_ref(); + let arg_dims = arg.shape().dims(); + let d1: usize = arg_dims.iter().take(dim).product(); + let d2 = block_size * arg_dims[dim]; + let dst_s = block_size * cat_target_dim_len; + let src_o = arg.layout().start_offset(); + arg.storage().copy2d( + &mut storage, + d1, + d2, + /* src_s */ d2, + dst_s, + src_o, + dst_o, + )?; + dst_o += d2; + } + Ok(crate::tensor::from_storage(storage, shape, op, false)) + } + + /// Set the values on `self` using values from `src`. The copy starts at the specified + /// `offset` for the target dimension `dim` on `self`. + /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size + /// has to be greater than or equal to `offset` plus the `src` size. + /// + /// Note that this modifies `self` in place and as such is not compatibel with + /// back-propagation. + pub fn slice_set(&self, src: &Self, dim: D, offset: usize) -> Result<()> { + let dim = dim.to_index(self.shape(), "slice-set")?; + if !self.is_contiguous() || !src.is_contiguous() { + Err(Error::RequiresContiguous { op: "slice-set" }.bt())? + } + if self.dtype() != src.dtype() { + Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: src.dtype(), + op: "slice-set", + } + .bt())? + } + if self.device().location() != src.device().location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: self.device().location(), + rhs: src.device().location(), + op: "slice-set", + } + .bt())? + } + if self.rank() != src.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: self.rank(), + got: src.rank(), + shape: self.shape().clone(), + } + .bt())? 
+ } + for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() { + if dim_idx == dim && *v2 + offset > *v1 { + crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}") + } + if dim_idx != dim && v1 != v2 { + crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}") + } + } + let block_size: usize = src.dims().iter().skip(1 + dim).product(); + let d1: usize = src.dims().iter().take(dim).product(); + let d2 = block_size * src.dims()[dim]; + let dst_o = self.layout().start_offset() + offset * block_size; + let src_o = src.layout().start_offset(); + src.storage().copy2d( + &mut self.storage_mut(), + d1, + d2, + /* src_s */ d2, + /* dst_s */ block_size * self.dims()[dim], + src_o, + dst_o, + )?; + + Ok(()) + } +} diff --git a/mcandle/mcandle-core/src/tensor_indexing.rs b/mcandle/mcandle-core/src/tensor_indexing.rs new file mode 100644 index 000000000..c5068d855 --- /dev/null +++ b/mcandle/mcandle-core/src/tensor_indexing.rs @@ -0,0 +1,379 @@ +use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; + +use crate::{ + bail, + op::{BackpropOp, Op}, + shape::Dim, + tensor::from_storage, + DType, Error, Result, Tensor, +}; + +/// Specialization of `std::ops::RangeBounds` for `usize` to allow trait objects. +pub trait RangeBound { + fn start_bound(&self) -> std::ops::Bound; + fn end_bound(&self) -> std::ops::Bound; +} + +macro_rules! range_bound { + ($name:ident) => { + impl RangeBound for $name { + fn end_bound(&self) -> std::ops::Bound { + >::end_bound(&self).cloned() + } + fn start_bound(&self) -> std::ops::Bound { + >::start_bound(&self).cloned() + } + } + }; + // Use the marker to designate no generics + ($name:ident, $marker:expr) => { + impl RangeBound for $name { + fn end_bound(&self) -> std::ops::Bound { + >::end_bound(&self).cloned() + } + fn start_bound(&self) -> std::ops::Bound { + >::start_bound(&self).cloned() + } + } + }; + // Use the marker to designate no generics + ($name:ty) => { + impl RangeBound for $name { + fn end_bound(&self) -> std::ops::Bound { + >::end_bound(&self).cloned() + } + fn start_bound(&self) -> std::ops::Bound { + >::start_bound(&self).cloned() + } + } + }; +} + +range_bound!(Range); +range_bound!(RangeFrom); +range_bound!(RangeFull, ()); +range_bound!(RangeInclusive); +range_bound!(RangeTo); +range_bound!(RangeToInclusive); +range_bound!((std::ops::Bound, std::ops::Bound)); + +impl RangeBound for usize { + fn end_bound(&self) -> std::ops::Bound { + std::ops::Bound::Excluded(self + 1) + } + fn start_bound(&self) -> std::ops::Bound { + std::ops::Bound::Included(*self) + } +} + +impl Tensor { + /// Returns a copy of `self` where the values within `ranges` have been replaced with the + /// content of `src`. This is analogous to slice asignment in `torch`. 
+ /// + /// # Example + /// ```rust + /// use mcandle_core::{Device, Tensor}; + /// + /// let dev = Device::Cpu; + /// let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + /// let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + /// let out = tensor.slice_assign(&[&(..3), &(3..5)], &src)?; + /// assert_eq!( + /// out.to_vec2::()?, + /// &[ + /// [0, 1, 2, 100, 101], + /// [5, 6, 7, 102, 103], + /// [10, 11, 12, 104, 105], + /// [15, 16, 17, 18, 19] + /// ] + /// ); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn slice_assign(&self, ranges: &[&dyn RangeBound], src: &Tensor) -> Result { + let src_dims = src.dims(); + let self_dims = self.dims(); + if self_dims.len() != src_dims.len() { + bail!( + "slice-assign requires input with the same rank {} <> {}", + self_dims.len(), + src_dims.len() + ) + } + if self_dims.len() != ranges.len() { + bail!( + "slice-assign requires input with the same rank as there are ranges {} <> {}", + self_dims.len(), + ranges.len() + ) + } + let mut src = src.clone(); + let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; + for (i, range) in ranges.iter().enumerate() { + let start_included = match range.start_bound() { + std::ops::Bound::Unbounded => 0, + std::ops::Bound::Included(v) => v, + std::ops::Bound::Excluded(v) => v + 1, + }; + let end_excluded = match range.end_bound() { + std::ops::Bound::Unbounded => self_dims[i], + std::ops::Bound::Included(v) => v + 1, + std::ops::Bound::Excluded(v) => v, + }; + if end_excluded <= start_included { + bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}") + } + if self_dims[i] < end_excluded { + bail!( + "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", + self_dims[i] + ) + } + if end_excluded - start_included != src_dims[i] { + bail!( + "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] + ) + } + src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; + mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? + } + mask.where_cond(/* on_true= */ &src, /* on_false= */ self) + } + + pub fn scatter_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "scatter-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (self, src)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + if indexes.dims() != source.dims() { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (indexes, src)", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + let storage = self.storage().scatter_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { + Op::ScatterAdd(t1, t2, t3, dim) + }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. 
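+    ///
+    /// A small sketch: the rows of `src` overwrite the rows of `self` starting at `start`
+    /// along the chosen dimension.
+    ///
+    /// ```rust
+    /// use mcandle_core::{Tensor, DType, Device};
+    /// let t = Tensor::zeros((3, 2), DType::F32, &Device::Cpu)?;
+    /// let src = Tensor::ones((1, 2), DType::F32, &Device::Cpu)?;
+    /// let out = t.slice_scatter(&src, 0, 2)?;
+    /// assert_eq!(out.to_vec2::<f32>()?, &[[0., 0.], [0., 0.], [1., 1.]]);
+    /// # Ok::<(), mcandle_core::Error>(())
+    /// ```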
+ pub fn slice_scatter(&self, src: &Self, dim: D, start: usize) -> Result { + let dim = dim.to_index(self.shape(), "slice-scatter")?; + if dim == 0 { + self.slice_scatter0(src, start) + } else { + // TODO: Maybe we want to add a more efficient implementation at some point. + self.transpose(0, dim)? + .slice_scatter0(&src.transpose(0, dim)?, start)? + .transpose(0, dim) + } + } + + /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension. + pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result { + if self.dtype() != src.dtype() { + Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: src.dtype(), + op: "slice-scatter", + } + .bt())? + } + if self.device().location() != src.device().location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: self.device().location(), + rhs: src.device().location(), + op: "slice-scatter", + } + .bt())? + } + if self.rank() != src.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: self.rank(), + got: src.rank(), + shape: src.shape().clone(), + } + .bt())? + } + let shape_ok = + self.dims() + .iter() + .zip(src.dims().iter()) + .enumerate() + .all(|(dim_idx, (&d1, &d2))| { + if 0 == dim_idx { + d2 + start <= d1 + } else { + d1 == d2 + } + }); + if !shape_ok { + Err(Error::ShapeMismatchBinaryOp { + op: "slice-scatter (self, src)", + lhs: self.shape().clone(), + rhs: src.shape().clone(), + } + .bt())? + } + let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let offset = start * src.dims()[1..].iter().product::(); + src.storage() + .copy_strided_src(&mut storage, offset, src.layout())?; + let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start)); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Accumulate element from `source` at indexes `indexes` and add them to `self`. + pub fn index_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "index-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (self, source)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + // The number of element in indexes must match the dimension on which the add is + // performed on the source tensor (and the index values from `indexes` are taken from + // the target tensor self) + let indexes_len = indexes.dims1()?; + if source_dims[dim] != indexes_len { + Err(Error::ShapeMismatchBinaryOp { + op: "index-add (ids, source))", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + let storage = self.storage().index_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| { + Op::IndexAdd(t1, t2, t3, dim) + }); + Ok(from_storage(storage, self.shape(), op, false)) + } + + /// Gather values across the target dimension. + /// + /// # Arguments + /// + /// * `self` - The input tensor. 
+ /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` + /// but can have a different number of elements on the target dimension. + /// * `dim` - the target dimension. + /// + /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on + /// dimension `dim` by the values in `indexes`. + pub fn gather(&self, indexes: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "gather")?; + let self_dims = self.dims(); + let indexes_dims = indexes.dims(); + let mismatch = if indexes_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "gather", + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + } + .bt())? + } + let storage = + self.storage() + .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; + let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim)); + Ok(from_storage(storage, indexes.shape(), op, false)) + } + + /// Select values for the input tensor at the target indexes across the specified dimension. + /// + /// The `indexes` is argument is an int tensor with a single dimension. + /// The output has the same number of dimension as the `self` input. The target dimension of + /// the output has length the length of `indexes` and the values are taken from `self` using + /// the index from `indexes`. Other dimensions have the same number of elements as the input + /// tensor. + pub fn index_select(&self, indexes: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "index-select")?; + let indexes_len = match indexes.dims() { + [l] => *l, + _ => Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + op: "index-select", + } + .bt())?, + }; + let storage = self.storage().index_select( + &indexes.storage(), + self.layout(), + indexes.layout(), + dim, + )?; + let mut dims = self.dims().to_vec(); + dims[dim] = indexes_len; + let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim)); + Ok(from_storage(storage, dims, op, false)) + } +} diff --git a/mcandle/mcandle-core/src/test_utils.rs b/mcandle/mcandle-core/src/test_utils.rs new file mode 100644 index 000000000..3b8fb904c --- /dev/null +++ b/mcandle/mcandle-core/src/test_utils.rs @@ -0,0 +1,62 @@ +use crate::{Result, Tensor}; + +#[macro_export] +macro_rules! test_device { + // TODO: Switch to generating the two last arguments automatically once concat_idents is + // stable. https://github.com/rust-lang/rust/issues/29599 + ($fn_name: ident, $test_cpu: ident, $test_cuda: ident, $test_metal: ident) => { + #[test] + fn $test_cpu() -> Result<()> { + $fn_name(&Device::Cpu) + } + + #[cfg(feature = "cuda")] + #[test] + fn $test_cuda() -> Result<()> { + $fn_name(&Device::new_cuda(0)?) + } + + #[cfg(feature = "metal")] + #[test] + fn $test_metal() -> Result<()> { + $fn_name(&Device::new_metal(0)?) 
+        }
+    };
+}
+
+pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec0::<f32>()?;
+    Ok(f32::round(t * b) / b)
+}
+
+pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec1::<f32>()?;
+    let t = t.iter().map(|t| f32::round(t * b) / b).collect();
+    Ok(t)
+}
+
+pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec2::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+        .collect();
+    Ok(t)
+}
+
+pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec3::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| {
+            t.iter()
+                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+                .collect()
+        })
+        .collect();
+    Ok(t)
+}
diff --git a/mcandle/mcandle-core/src/utils.rs b/mcandle/mcandle-core/src/utils.rs
new file mode 100644
index 000000000..78c45a9a9
--- /dev/null
+++ b/mcandle/mcandle-core/src/utils.rs
@@ -0,0 +1,44 @@
+use std::str::FromStr;
+
+pub fn get_num_threads() -> usize {
+    // Respond to the same environment variable as rayon.
+    match std::env::var("RAYON_NUM_THREADS")
+        .ok()
+        .and_then(|s| usize::from_str(&s).ok())
+    {
+        Some(x) if x > 0 => x,
+        Some(_) | None => num_cpus::get(),
+    }
+}
+
+pub fn has_accelerate() -> bool {
+    cfg!(feature = "accelerate")
+}
+
+pub fn has_mkl() -> bool {
+    cfg!(feature = "mkl")
+}
+
+pub fn cuda_is_available() -> bool {
+    cfg!(feature = "cuda")
+}
+
+pub fn metal_is_available() -> bool {
+    cfg!(feature = "metal")
+}
+
+pub fn with_avx() -> bool {
+    cfg!(target_feature = "avx")
+}
+
+pub fn with_neon() -> bool {
+    cfg!(target_feature = "neon")
+}
+
+pub fn with_simd128() -> bool {
+    cfg!(target_feature = "simd128")
+}
+
+pub fn with_f16c() -> bool {
+    cfg!(target_feature = "f16c")
+}
diff --git a/mcandle/mcandle-core/src/variable.rs b/mcandle/mcandle-core/src/variable.rs
new file mode 100644
index 000000000..1e4880e59
--- /dev/null
+++ b/mcandle/mcandle-core/src/variable.rs
@@ -0,0 +1,152 @@
+// Variables are wrappers around tensors that can be modified; they are typically used to hold
+// weights that get updated by gradient descent.
+// We do not expose a public way to create variables as this would break the invariant that the
+// tensor within a variable actually has `is_variable` set to `true`.
+use crate::{DType, Device, Error, Result, Shape, Tensor};
+
+/// A variable is a wrapper around a tensor; unlike a plain tensor, which is immutable, a
+/// variable can have its content modified.
+#[derive(Clone, Debug)]
+pub struct Var(Tensor);
+
+impl std::fmt::Display for Var {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        std::fmt::Display::fmt(&self.0, f)
+    }
+}
+
+impl std::ops::Deref for Var {
+    type Target = Tensor;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+
+impl Var {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    // Convert a tensor to a variable; if the tensor is already a variable it is returned as is.
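+    // A hypothetical usage sketch (illustrative only, not part of this diff's tests):
+    //     let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
+    //     let v = Var::from_tensor(&t)?; // ops on `v` are tracked so backward() can return its gradient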
+    pub fn from_tensor(t: &Tensor) -> Result<Self> {
+        if t.is_variable() {
+            Ok(Self(t.clone()))
+        } else {
+            let inner = t.make_var()?;
+            Ok(Self(inner))
+        }
+    }
+
+    pub fn rand_f64<S: Into<Shape>>(
+        lo: f64,
+        up: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn_f64<S: Into<Shape>>(
+        mean: f64,
+        std: f64,
+        s: S,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
+        s: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
+        s: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
+        Ok(Self(inner))
+    }
+
+    /// Creates a new tensor on the specified device using the content and shape of the input.
+    /// This is similar to `Tensor::new` but the resulting tensor is a variable.
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
+        let shape = array.shape()?;
+        let inner = Tensor::new_impl(array, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn as_detached_tensor(&self) -> Tensor {
+        self.0.detach()
+    }
+
+    pub fn as_tensor(&self) -> &Tensor {
+        &self.0
+    }
+
+    /// Consumes this `Var` and returns the underlying tensor.
+    pub fn into_inner(self) -> Tensor {
+        self.0
+    }
+
+    /// Sets the content of the inner tensor; this does not require a mutable reference as inner
+    /// mutability is used.
+    pub fn set(&self, src: &Tensor) -> Result<()> {
+        if self.same_storage(src) {
+            let msg = "cannot set a variable to a tensor that is derived from its value";
+            Err(Error::CannotSetVar { msg }.bt())?
+        }
+        let (mut dst, layout) = self.storage_mut_and_layout();
+        if !layout.is_contiguous() {
+            let msg = "cannot set a non-contiguous variable";
+            Err(Error::CannotSetVar { msg }.bt())?
+        }
+        let (src, src_l) = src.storage_and_layout();
+        if layout.shape() != src_l.shape() {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: layout.shape().clone(),
+                rhs: src_l.shape().clone(),
+                op: "set",
+            }
+            .bt())?
+        }
+        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
+        Ok(())
+    }
+}
diff --git a/mcandle/mcandle-core/tests/conv_tests.rs b/mcandle/mcandle-core/tests/conv_tests.rs
new file mode 100644
index 000000000..53d992f7d
--- /dev/null
+++ b/mcandle/mcandle-core/tests/conv_tests.rs
@@ -0,0 +1,864 @@
+use anyhow::Result;
+use mcandle_core::{test_device, test_utils, Device, IndexOp, Tensor};
+
+/* This test is based on the following script.
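+
+(For reference when reading the Rust calls below: the positional arguments after the kernel
+are padding, stride, dilation, groups for `conv1d`, and padding, output_padding, stride,
+dilation, groups for `conv_transpose1d`, mirroring the PyTorch keyword arguments used in
+this script.)
+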
+import torch +torch.manual_seed(4242) + +t = torch.randn((1, 4, 5)) +w = torch.randn((2, 4, 3)) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv1d(t, w) +print(res.flatten()) +res = torch.nn.functional.conv1d(t, w, padding=1) +print(res.flatten()) + +w_t = w.transpose(0, 1) +res = torch.nn.functional.conv_transpose1d(t, w_t) +print(res.shape) +print(res) +res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2) +print(res.shape) +print(res) +*/ +fn conv1d(dev: &Device) -> Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145, + 1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599, + ], + dev, + )? + .reshape((1, 4, 5))?; + let w = Tensor::new( + &[ + -0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181, + -1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261, + -0.6451, -0.0840, -1.4247, 0.5512, + ], + dev, + )? + .reshape((2, 4, 3))?; + let res = t.conv1d(&w, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 3]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069] + ); + let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 5]); + // Same as pytorch default padding: use zeros. + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] + ); + + let w = w.transpose(0, 1)?; + // The CPU kernels applied in the contiguous and non contiguous cases are different. + for w in [w.clone(), w.contiguous()?] { + let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 7]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [ + 0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538, + 4.7076, -5.9745, -0.8276, 1.621 + ], + ); + let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?; + assert_eq!(res.dims(), [1, 4, 7]); + assert_eq!( + test_utils::to_vec2_round(&res.squeeze(0)?, 4)?, + [ + [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819], + [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721], + [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113], + [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488] + ] + ); + } + Ok(()) +} + +fn conv1d_small(dev: &Device) -> Result<()> { + let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?; + let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?; + let res = t.conv1d(&w, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 2]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [0.4056, -0.8689] + ); + let res = t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 4]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [0.0, 0.4056, -0.8689, -0.0773], + ); + Ok(()) +} + +/* This test is based on the following script. 
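+
+(Note on the `transpose(0, 1)` below: a conv2d weight is laid out as (out_channels,
+in_channels, kh, kw) while conv_transpose2d expects (in_channels, out_channels, kh, kw),
+so swapping the first two dims turns one into the other; the Rust test applies the same
+transpose before calling `conv_transpose2d`.)
+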
+import torch +torch.manual_seed(4242) + +t = torch.randn((1, 4, 5, 5)) +w = torch.randn((2, 4, 3, 3)) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv2d(t, w) +print(res.flatten()) + +w_t = w.transpose(0, 1) +res = torch.nn.functional.conv_transpose2d(t, w_t) +print(res.shape) +print(res) + +res = torch.nn.functional.conv2d(t, w, dilation=2) +print(res.shape) +print(res[0]) + +res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2) +print(res.shape) +print(res) +*/ +fn conv2d(dev: &Device) -> Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, + 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395, + 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836, + 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123, + 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586, + 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049, + 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, + 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, + -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, + -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, + ], + dev, + )?; + let w = Tensor::new( + &[ + -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273, + -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514, + -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027, + 0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667, + 0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679, + -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646, + 1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, + 0.5583, 0.4623, 0.6026, + ], + dev, + )?; + let t = t.reshape((1, 4, 5, 5))?; + let w = w.reshape((2, 4, 3, 3))?; + let res = t.conv2d(&w, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 3, 3]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [ + -4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715, + 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 + ] + ); + + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + + assert_eq!(res.dims(), [1, 2, 7, 7]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, + [ + [ + [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], + [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], + [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], + [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], + [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], + [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], + [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] + ], + [ + [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], + [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], + [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], + [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], + [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059], + [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], + [-5.5423, -2.5188, 1.0754, 
-0.0563, -2.9386, -1.1504, 1.0171] + ] + ] + ); + + // Dilations. + let res = t.conv2d(&w, 0, 1, 2, 1)?; + assert_eq!(res.dims(), [1, 2, 1, 1]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [2.45, -2.3504], + ); + + // Transpose and dilations. + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; + assert_eq!(res.dims(), [1, 2, 9, 9]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, + [ + [ + [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], + [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], + [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], + [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], + [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], + [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], + [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024], + [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], + [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] + ], + [ + [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], + [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], + [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], + [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], + [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], + [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957], + [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856], + [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908], + [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171] + ] + ] + ); + + Ok(()) +} + +/* This test is based on the following script. 
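+
+(With a 1x1 kernel, conv2d reduces to a per-pixel weighted sum over the input channels. As a
+hand-computed sanity check on the expected values below: the first output element is
+-0.9259 * 0.4056 + 1.3017 * 0.4145 = 0.1640 to four decimals, matching the first expected
+value 0.164.)
+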
+import torch +torch.manual_seed(4242) + +t = torch.randn((1, 2, 3, 3)) +w = torch.randn((1, 2, 1, 1)) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv2d(t, w) +print(res.flatten()) + +w_t = w.transpose(0, 1) +res = torch.nn.functional.conv_transpose2d(t, w_t) +print(res.shape) +print(res.flatten()) + +t_t = w.transpose(0, 1) +res = torch.nn.functional.conv_transpose2d(t_t, w) +print(res.shape) +print(res.flatten()) +*/ +fn conv2d_small(dev: &Device) -> Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145, + -0.6266, 0.3529, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, + ], + dev, + )?; + let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?; + let t = t.reshape((1, 2, 3, 3))?; + let w = w.reshape((1, 2, 1, 1))?; + let res = t.conv2d(&w, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 3, 3]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539] + ); + let res = t.conv2d(&w, 2, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 7, 7]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640, + -0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0, + 3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0 + ] + ); + + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 3, 3]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [0.164, -0.0111, -0.1742, 2.6437, -2.0268, 1.1823, 3.2855, -1.0324, 0.2539], + ); + let res = t.transpose(0, 1)?.conv_transpose2d(&w, 0, 0, 1, 1)?; + assert_eq!(res.dims(), [2, 2, 3, 3]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [ + -0.3755, 0.8045, -0.6336, -0.2218, -1.1369, 0.8599, 1.5768, -0.1268, -0.1728, 0.528, + -1.131, 0.8908, 0.3118, 1.5984, -1.2089, -2.2168, 0.1783, 0.2429, -0.3838, 0.5802, + -0.3268, -2.0382, 0.6329, -0.2293, -1.2154, 0.6441, -0.3035, 0.5396, -0.8156, 0.4594, + 2.8654, -0.8898, 0.3224, 1.7087, -0.9056, 0.4267 + ] + ); + Ok(()) +} + +fn conv2d_smaller(dev: &Device) -> Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, + ], + dev, + )?; + let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?; + let t = t.reshape((1, 1, 3, 3))?; + let w = w.reshape((1, 1, 3, 3))?; + let res = t.conv2d(&w, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 1, 1]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [-0.6197] + ); + Ok(()) +} + +/* This test is based on the following script. 
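+
+(The spatial input here is deliberately non-square, 4 x 2; since the kernel is 1x1 the
+spatial shape is preserved, which is why the test below asserts an output of dims
+[1, 1, 4, 2].)
+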
+import torch +torch.manual_seed(4242) + +t = torch.randn((1, 2, 4, 2)) +w = torch.randn((1, 2, 1, 1)) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv2d(t, w) +print(res.flatten()) +*/ +fn conv2d_non_square(dev: &Device) -> Result<()> { + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, + 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, + ], + dev, + )?; + let w = Tensor::new(&[-1.1351f32, 1.3841], dev)?; + let t = t.reshape((1, 2, 4, 2))?; + let w = w.reshape((1, 2, 1, 1))?; + let res = t.conv2d(&w, 0, 1, 1, 1)?; + assert_eq!(res.dims(), [1, 1, 4, 2]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [0.2312, 5.2238, 2.3772, 1.9076, 2.0256, -0.5776, -1.6028, -1.467] + ); + Ok(()) +} + +/* +import torch +torch.manual_seed(4242) + +t = torch.randn((1, 4, 5, 5), requires_grad=True) +w = torch.randn((2, 4, 3, 3), requires_grad=True) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv2d(t, w) +print(res.flatten()) +loss = (res ** 2).sum() +print(loss) +loss.backward() +print(t.grad.shape) +print(t.grad.flatten()) +print(w.grad.shape) +print(w.grad.flatten()) + +t.grad.zero_() +w.grad.zero_() +res = torch.nn.functional.conv2d(t, w, stride=2) +print(res.flatten()) +loss = (res ** 2).sum() +print(loss) +loss.backward() +print(t.grad.shape) +print(t.grad[0]) +print(w.grad.shape) +print(w.grad[0]) +*/ +fn conv2d_grad(dev: &Device) -> Result<()> { + // conv-transposes are not implemented for metal + use mcandle_core::Var; + let t = Var::from_slice( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, + 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395, + 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836, + 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123, + 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586, + 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049, + 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, + 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, + -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, + -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, + ], + (1, 4, 5, 5), + dev, + )?; + let w = Var::from_slice( + &[ + -0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273, + -2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514, + -0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027, + 0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667, + 0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679, + -0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646, + 1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, + 0.5583, 0.4623, 0.6026, + ], + (2, 4, 3, 3), + dev, + )?; + let res = t.conv2d(&w, 0, 1, 1, 1)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32); + let grads = loss.backward()?; + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 5, 5]); + assert_eq!(grad_w.dims(), [2, 4, 3, 3]); + assert_eq!( + 
test_utils::to_vec1_round(&grad_t.flatten_all()?, 2)?, + [ + 9.29, -2.84, -5.71, 3.38, -7.71, -19.15, 7.02, 29.1, 9.34, 34.73, -22.87, 24.35, + -39.88, -14.01, 21.08, 9.94, 13.63, -34.68, 11.21, -6.26, 7.72, -6.32, -16.64, -1.08, + -20.22, 21.73, -0.37, -4.06, 5.82, -3.65, -30.73, 14.55, 87.7, 31.6, 4.53, -89.78, + -75.37, -57.43, -7.56, 92.96, 18.79, -4.63, -159.75, -42.47, -47.26, 52.88, 37.32, + 49.0, 12.82, 2.01, -8.98, 20.18, 16.62, 12.06, 15.38, 20.0, 2.57, -15.22, 72.62, + -10.75, 2.25, -31.2, 3.75, -0.2, 9.76, -0.68, 5.21, -40.44, -22.59, -61.61, 17.28, + 20.41, 37.55, 5.23, 6.81, 23.54, 23.62, -9.99, -9.13, 4.87, -35.06, -26.1, 63.48, + 25.81, -39.21, -70.68, -46.96, 2.33, 41.81, 82.42, -28.63, -11.78, -35.33, -10.28, + -28.57, -9.13, 7.21, -9.05, -9.62, -11.25 + ] + ); + assert_eq!( + test_utils::to_vec1_round(&grad_w.flatten_all()?, 2)?, + [ + -28.92, -22.88, -141.23, 73.35, 61.07, 47.81, -20.0, -73.71, -41.82, -13.59, 21.5, + 28.72, 28.57, -46.85, -90.19, 143.61, 16.68, 7.43, 18.88, -90.81, -20.29, 54.79, 82.63, + 22.94, 77.81, -16.39, -13.2, 9.34, -40.39, -26.62, 5.33, -60.91, 9.09, -59.37, 7.08, + 58.64, 5.55, 20.52, 2.5, -17.25, -6.8, 22.21, 30.15, -7.52, -37.46, 5.67, 22.58, 9.03, + 47.05, 17.61, 37.31, -98.13, -14.61, -4.8, -6.36, 44.69, 23.34, 8.37, -13.52, 80.05, + -34.24, -16.36, -12.31, 1.92, -33.62, -14.1, -49.23, -7.39, 11.5, -9.98, 9.66, 29.6 + ] + ); + + // Same as before but with stride. + let res = t.conv2d(&w, 0, 2, 1, 1)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 277.16f32); + let grads = loss.backward()?; + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 5, 5]); + assert_eq!(grad_w.dims(), [2, 4, 3, 3]); + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 2)?, + [ + [ + [9.29, -7.03, 0.94, 3.49, -7.71], + [-1.8, -7.82, 8.9, 8.46, 7.43], + [-25.84, 22.09, -19.27, -0.22, 1.69], + [4.02, 18.53, -18.37, 2.3, -24.51], + [7.72, -9.68, -12.34, 5.6, -20.22] + ], + [ + [21.73, 3.39, -18.27, 3.86, -3.65], + [8.25, 3.73, 30.73, -8.61, -11.93], + [-72.15, -15.36, -17.53, -12.32, -1.61], + [-22.32, -7.79, -91.82, 6.44, -37.69], + [52.88, 14.44, 42.75, 9.88, 2.01] + ], + [ + [-8.98, 9.91, 6.75, -4.68, 15.38], + [4.93, -0.33, 9.94, -1.46, 14.78], + [13.62, -30.63, 3.96, -3.58, -4.48], + [-14.13, 1.19, -34.43, 3.08, -33.83], + [17.28, 12.94, 31.83, -3.35, 6.81] + ], + [ + [23.54, 6.98, -24.52, 0.52, 4.87], + [9.65, 6.18, 1.71, -25.23, -4.93], + [-54.99, -23.66, 3.19, -3.73, 18.58], + [-21.35, -10.39, -39.88, 28.73, -30.76], + [-9.13, 11.12, -14.0, -8.23, -11.25] + ] + ] + ); + assert_eq!( + test_utils::to_vec3_round(&grad_w.i(0)?, 2)?, + [ + [ + [28.34, -7.91, -45.75], + [21.03, 3.86, 29.86], + [0.72, -36.58, -35.28] + ], + [ + [-16.04, 11.53, -16.38], + [29.62, -16.32, -48.35], + [57.5, 28.29, 25.81] + ], + [ + [2.93, -19.6, 1.57], + [27.15, 53.88, -24.64], + [12.74, -22.6, -26.2] + ], + [ + [-0.18, -14.86, -6.82], + [-19.55, -2.72, 45.9], + [-2.54, 36.97, 27.11] + ] + ] + ); + + // Replicate the issue from https://github.com/huggingface/mcandle/issues/1212 + let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32); + let grads = loss.backward()?; + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 5, 5]); + assert_eq!(grad_w.dims(), [2, 4, 3, 3]); + assert_eq!( + 
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?, + [ + [ + [9.29, -7.03, 7.87, 0.0, 0.0], + [-1.8, -7.82, 5.9, 0.0, 0.0], + [-3.12, 4.49, 5.52, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [21.73, 3.39, 4.77, 0.0, 0.0], + [8.25, 3.73, 27.61, 0.0, 0.0], + [-20.55, -5.61, -2.77, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [-8.98, 9.91, -7.15, 0.0, 0.0], + [4.93, -0.33, 4.56, 0.0, 0.0], + [-6.7, -5.76, -8.05, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [23.54, 6.98, -10.0, 0.0, 0.0], + [9.65, 6.18, 18.72, 0.0, 0.0], + [3.29, -5.27, 0.79, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ] + ] + ); + assert_eq!( + test_utils::to_vec3_round(&grad_w.i(0)?, 2)?, + [ + [ + [-3.47, 7.44, 0.66], + [12.89, -3.4, -9.29], + [-14.16, -0.83, 7.14] + ], + [ + [-3.23, 5.37, -3.02], + [-2.12, -11.24, 1.94], + [6.97, 7.2, 2.99] + ], + [ + [-4.04, -3.31, 4.87], + [-6.68, -5.68, 1.73], + [-5.54, 4.32, 0.52] + ], + [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]] + ] + ); + + // Conv Transpose 2d Test + //tested against following python + + // import torch + // torch.manual_seed(4242) + // padding = 4 + // outpadding = 2 + // dilation = 3 + // stride = 3 + // input = torch.randn((1, 4, 7, 5), requires_grad=True) + // kernel = torch.randn((4, 2, 3, 5), requires_grad=True) + // print("input", input.flatten()) + // print("kernel", kernel.flatten()) + // res = torch.nn.functional.conv_transpose2d( + // input, + // kernel, + // stride=stride, + // padding=padding, + // dilation=dilation, + // output_padding=outpadding, + // ) + // res.retain_grad() + // print(res.shape) + // loss = (res**2).sum() + // print(loss) + // loss.backward() + // print(input.grad.shape) + // print("input grad", torch.round(input.grad, decimals=1)) + // print(kernel.grad.shape) + // print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1)) + + let padding = 4; + let outpadding = 2; + let dilation = 3; + let stride = 3; + + let t = Var::from_slice( + &[ + 0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, + 3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, + 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, + -0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, + 1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, + 1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, + 0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, + -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, + 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912, + -0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465, + -0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264, + 1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451, + -0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258, + -2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186, + 1.6475, 0.2219, + ], + (1, 4, 7, 5), + dev, + )?; + + #[rustfmt::skip] + let w = Var::from_slice( + &[ + -1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234, + -0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 
1.0762, + 0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204, + 0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555, + 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990, + 0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181, + 0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481, + 0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509, + 0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732, + -0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071, + -1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604, + 0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478, + ], + (4, 2, 3, 5), + dev, + )?; + let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0); + let grads = loss.backward()?; + + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 7, 5]); + assert_eq!(grad_w.dims(), [4, 2, 3, 5]); + + assert_eq!( + test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?, + [ + // torch gets 89.1 + -89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0, + -15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9, + 52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2, + 106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6, + -27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5, + -10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0, + -52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9, + -20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5, + 92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5, + -28.4, 85.0, -18.3, 107.0, 28.3, -71.8 + ] + ); + + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 1)?, + [ + [ + [32.3, -41.6, -24.0, 14.1, 17.6], + [-11.8, 72.5, 87.6, 46.4, 61.5], + [115.0, 108.5, -48.6, -63.4, -50.0], + [51.3, 5.4, 31.3, 91.1, -30.9], + [52.7, 92.8, -68.0, -47.0, 83.0], + // pytorch gets -107.1 + [-10.2, -107.0, -5.4, 213.1, -31.4], + [-2.4, 65.1, 9.2, -146.2, -24.2] + ], + [ + [-72.6, -63.9, -61.9, 45.3, 33.0], + [79.3, -0.5, -26.2, 78.2, 42.7], + [90.9, 141.6, 40.1, -62.7, 37.0], + [32.8, 198.2, -0.8, -31.1, 27.3], + // torch gets 48.0 + [34.5, 34.9, -47.9, 127.6, -12.3], + [-61.4, -3.2, -2.9, -10.9, -16.6], + [74.6, 60.1, -68.9, 34.5, -50.4] + ], + [ + [37.5, -56.9, -43.6, -13.5, -9.9], + [40.0, 97.3, 28.6, 14.2, -30.1], + [-22.3, -126.3, -68.8, -8.2, 26.1], + [-32.9, 37.3, 108.5, -54.8, 29.6], + [34.9, -176.9, -125.0, -28.3, -13.9], + [-54.9, 142.6, 62.1, -80.4, -65.6], + [7.4, -91.1, -67.6, 35.0, 39.7] + ], + [ + [-57.2, -40.9, -10.1, 32.6, 29.4], + [18.7, -18.0, 29.5, -1.2, 59.2], + [-14.0, -74.4, 19.8, -117.0, 58.2], + [-21.8, 163.5, -71.1, -99.0, 80.9], + [-58.9, -10.9, 93.8, -139.6, 98.0], + // torch gets 54.5 + [-54.4, 135.3, 6.0, -79.1, 134.6], + [27.5, -76.0, 43.4, -2.8, -7.8] + ] + ] + ); + + // Test the same, but then with the following properties, t & w are unmodified. 
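+    // For reference, the expected spatial size follows the standard transposed-convolution
+    // formula (as documented for torch.nn.ConvTranspose2d):
+    //     out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
+    // With the parameters below and the (7, 5) input / (3, 5) kernel from above, this gives a
+    // (1, 2, 14, 12) result (hand-derived here; the test only checks the loss and gradients).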
+ let padding = 1; + let outpadding = 1; + let dilation = 1; + let stride = 2; + + let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 3627.0); // torch gives 3626.8560 + + let grads = loss.backward()?; + + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 7, 5]); + assert_eq!(grad_w.dims(), [4, 2, 3, 5]); + + #[rustfmt::skip] + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 1)?, + [ + [ + [ 13.2, -40.7, -9.7, -47.3, -82.7], + [ -98.2, 9.7, 57.7, -6.2, 180.7], + [ 100.2, 24.1, 3.7, -100.5, -48.1], + [ -0.3, 13.5, -2.9, 80.0, -49.8], + [ 47.2, -25.6, -74.4, 61.2, -18.4], + [ 4.6, -69.5, 27.9, 66.5, -88.1], + // 4th column on next row; torch is 4.2 + [ -12.0, 79.2, -40.0, 4.1, -97.1], + ], + [ + [ -42.2, -36.5, -51.1, 7.5, 32.3], + [ 74.1, -44.6, -68.8, 19.5, 7.7], + [ 137.1, 54.2, 153.8, -58.0, 45.5], + [ 24.4, -56.8, 9.7, -41.0, -14.5], + [ -3.7, 72.6, 8.3, 134.8, 40.5], + [ 43.2, -56.9, -47.5, -89.4, -95.4], + [ 68.2, 108.1, -80.0, 57.0, -121.1] + ], + [ + [ 31.1, -11.4, -34.8, 33.1, -44.2], + [ 29.4, -31.6, -40.2, 13.7, 13.1], + [ -0.8, -83.8, -7.8, -17.3, 78.2], + [ 12.0, -118.7, 137.5, -76.7, 50.8], + [ -28.7, -114.2, -3.7, -96.3, -13.8], + [ -31.8, 28.5, -14.3, 4.6, 13.4], + [ 28.0, -0.2, -38.9, -29.7, -59.0] + ], + [ + [ -16.8, 38.5, 15.5, 26.6, 48.9], + [ 14.5, 49.6, -24.8, 65.6, 61.7], + [ 22.1, -64.7, -4.3, -51.0, 36.3], + [ 31.0, -88.9, 47.1, -123.5, -3.8], + [ -14.8, -39.8, 128.2, -110.3, 42.6], + // 1st column on next row; torch is -7.2 + [ -7.1, 95.3, -21.3, -58.7, -13.9], + [ 26.9, 21.3, 16.1, 70.3, 32.1] + ] + ] + ); + + #[rustfmt::skip] + assert_eq!( + test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?, + [ + // 2nd value; torch gets -3.2, 3rd value; torch gets 221.8 + -2.460e+01, -3.100e+00, 2.219e+02, 7.400e+00, 5.620e+01, + 7.420e+01, 7.830e+01, 8.900e+00, 1.050e+01, 2.810e+01, + 5.100e+00, -1.046e+02, -1.572e+02, 8.710e+01, -9.840e+01, + -4.230e+01, -1.898e+02, 1.860e+01, -3.570e+01, 9.810e+01, + 4.680e+01, 1.182e+02, 4.020e+01, -1.900e+00, 1.508e+02, + 1.094e+02, 1.018e+02, -4.620e+01, 1.591e+02, -2.320e+01, + // 5th value; torch gets 7.1 + -8.450e+01, -4.600e+00, 6.330e+01, 1.123e+02, -7.000e+00, + 1.101e+02, -6.620e+01, 2.090e+01, -5.120e+01, 8.990e+01, + 9.050e+01, -6.990e+01, 6.800e+01, -9.250e+01, 1.380e+02, + 4.720e+01, 4.710e+01, 6.210e+01, 8.870e+01, 2.098e+02, + 3.870e+01, -1.390e+01, 6.270e+01, 1.484e+02, -9.920e+01, + -4.200e+01, -1.505e+02, -1.480e+01, -2.620e+01, 8.220e+01, + -3.350e+01, -2.260e+01, -1.198e+02, -5.080e+01, 1.259e+02, + 5.600e+01, 9.270e+01, 1.209e+02, 6.590e+01, -8.330e+01, + 7.000e+00, -2.600e+01, -1.133e+02, 3.870e+01, 4.020e+01, + -6.300e+00, -8.710e+01, -5.150e+01, -8.510e+01, 2.000e-01, + 3.640e+01, -6.100e+00, 6.590e+01, -2.700e+00, 6.550e+01, + // 4th value; torch gets 3.8 + 5.300e+00, -6.760e+01, -4.270e+01, -3.900e+00, 2.880e+01, + 5.260e+01, 6.170e+01, -1.203e+02, -1.610e+01, 7.740e+01, + -1.008e+02, -1.070e+01, -9.900e+00, 3.300e+00, -2.620e+01, + -4.440e+01, 2.580e+01, -6.920e+01, -4.220e+01, 1.108e+02, + 1.240e+01, -3.440e+01, -2.800e+00, 7.880e+01, -6.690e+01, + 1.480e+01, 2.310e+01, -4.260e+01, -1.500e+00, -4.760e+01, + 5.350e+01, -2.260e+01, 8.000e-01, -3.840e+01, -2.500e+00 + ] + ); + + Ok(()) +} + +test_device!(conv1d, conv1d_cpu, conv1d_gpu, conv1d_metal); +test_device!( + conv1d_small, + conv1d_small_cpu, + conv1d_small_gpu, + 
conv1d_small_metal +); +test_device!(conv2d, conv2d_cpu, conv2d_gpu, conv2d_metal); +test_device!( + conv2d_non_square, + conv2d_non_square_cpu, + conv2d_non_square_gpu, + conv2d_non_square_metal +); +test_device!( + conv2d_small, + conv2d_small_cpu, + conv2d_small_gpu, + conv2d_small_metal +); +test_device!( + conv2d_smaller, + conv2d_smaller_cpu, + conv2d_smaller_gpu, + conv2d_smaller_metal +); +test_device!( + conv2d_grad, + conv2d_grad_cpu, + conv2d_grad_gpu, + conv2_grad_metal +); diff --git a/mcandle/mcandle-core/tests/custom_op_tests.rs b/mcandle/mcandle-core/tests/custom_op_tests.rs new file mode 100644 index 000000000..bab9623a7 --- /dev/null +++ b/mcandle/mcandle-core/tests/custom_op_tests.rs @@ -0,0 +1,181 @@ +use mcandle_core::backend::BackendStorage; +use mcandle_core::cpu_backend; +use mcandle_core::test_utils::to_vec1_round; +use mcandle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor}; + +fn fwd(v: T, alpha: f64) -> T { + if v.is_sign_positive() { + v + } else { + let alpha = T::from(alpha).unwrap_or(T::nan()); + (v.exp() - T::one()) * alpha + } +} + +struct Elu { + alpha: f64, +} + +impl CustomOp1 for Elu { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + let storage = mcandle_core::map_dtype!( + "elu", + s, + |s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)), + (BF16, F16, F32, F64) + ); + Ok((storage, l.shape().clone())) + } +} + +#[test] +fn custom_op1_no_backward() -> Result<()> { + let cpu = &Device::Cpu; + let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?; + let t = (t - 5.)?; + let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?; + assert_eq!( + to_vec1_round(&elu_t, 4)?, + &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ); + Ok(()) +} + +// Define a similar struct as Elu but with backward support. 
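+// The ELU derivative is 1 where `is_sign_positive()` holds and alpha * exp(x) otherwise
+// (the derivative of alpha * (exp(x) - 1)), which is what `bwd` computes element-wise.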
+fn bwd(v: T, alpha: f64) -> T { + if v.is_sign_positive() { + T::one() + } else { + let alpha = T::from(alpha).unwrap_or(T::nan()); + v.exp() * alpha + } +} + +struct EluBackward { + alpha: f64, +} + +impl CustomOp1 for EluBackward { + fn name(&self) -> &'static str { + "elu-bwd" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + let storage = mcandle_core::map_dtype!( + "elu-bwd", + s, + |s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)), + (BF16, F16, F32, F64) + ); + Ok((storage, l.shape().clone())) + } +} + +struct EluWithBackward(Elu); + +impl EluWithBackward { + fn new(alpha: f64) -> Self { + Self(Elu { alpha }) + } +} + +impl CustomOp1 for EluWithBackward { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> { + self.0.cpu_fwd(s, l) + } + + fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result> { + let alpha = self.0.alpha; + let bwd = arg.apply_op1(EluBackward { alpha })?; + Ok(Some(grad_res.mul(&bwd)?)) + } +} + +#[test] +fn custom_op1_with_backward() -> Result<()> { + let cpu = &Device::Cpu; + let t = mcandle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?; + let elu_t = t.apply_op1(EluWithBackward::new(2.))?; + assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]); + + let grads = elu_t.backward()?; + let grad_x = grads.get(&t).unwrap(); + assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]); + + Ok(()) +} + +impl mcandle_core::InplaceOp1 for Elu { + fn name(&self) -> &'static str { + "elu" + } + + fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> { + let alpha = self.alpha; + match s { + CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)), + _ => mcandle_core::bail!("unsupported dtype for inplace elu"), + } + Ok(()) + } +} + +#[test] +fn inplace_op1() -> Result<()> { + let cpu = &Device::Cpu; + let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?; + let t = (t - 5.)?; + t.inplace_op1(&Elu { alpha: 1. })?; + assert_eq!( + to_vec1_round(&t, 4)?, + &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ); + Ok(()) +} + +#[cfg(any(feature = "cuda", feature = "metal"))] +#[allow(clippy::approx_constant)] +#[test] +fn ug_op() -> Result<()> { + let kernel = { + use ug::lang::op; + + let layout = ug::Layout::from_shape(&[12]); + let ptr = op::Arg::ptr(ug::DType::F32); + let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?; + let src = op::unary(op::UnaryOp::Exp, src)?; + let st = op::store(ptr.id(), layout, src)?; + let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]); + let opts: ug::lower_op::Opts = Default::default(); + kernel.lower(&opts.with_global(0, 12))? + }; + let device = if mcandle_core::utils::cuda_is_available() { + Device::new_cuda(0)? + } else if mcandle_core::utils::metal_is_available() { + Device::new_metal(0)? 
+ } else { + mcandle_core::bail!("metal/cuda is mandatory for this test") + }; + let op = mcandle_core::UgIOp1::new("test", kernel, &device)?; + let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; + t.inplace_op1(&op)?; + assert_eq!( + to_vec1_round(&t, 2)?, + &[ + 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47, + 59874.13 + ] + ); + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/display_tests.rs b/mcandle/mcandle-core/tests/display_tests.rs new file mode 100644 index 000000000..60a4f2602 --- /dev/null +++ b/mcandle/mcandle-core/tests/display_tests.rs @@ -0,0 +1,84 @@ +use anyhow::Result; +use mcandle_core::{DType, Device::Cpu, Tensor}; + +#[test] +fn display_scalar() -> Result<()> { + let t = Tensor::new(1234u32, &Cpu)?; + let s = format!("{t}"); + assert_eq!(&s, "[1234]\nTensor[[], u32]"); + let t = t.to_dtype(DType::F32)?.neg()?; + let s = format!("{}", (&t / 10.0)?); + assert_eq!(&s, "[-123.4000]\nTensor[[], f32]"); + let s = format!("{}", (&t / 1e8)?); + assert_eq!(&s, "[-1.2340e-5]\nTensor[[], f32]"); + let s = format!("{}", (&t * 1e8)?); + assert_eq!(&s, "[-1.2340e11]\nTensor[[], f32]"); + let s = format!("{}", (&t * 0.)?); + assert_eq!(&s, "[0.]\nTensor[[], f32]"); + Ok(()) +} + +#[test] +fn display_vector() -> Result<()> { + let t = Tensor::new::<&[u32; 0]>(&[], &Cpu)?; + let s = format!("{t}"); + assert_eq!(&s, "[]\nTensor[[0], u32]"); + let t = Tensor::new(&[0.1234567, 1.0, -1.2, 4.1, f64::NAN], &Cpu)?; + let s = format!("{t}"); + assert_eq!( + &s, + "[ 0.1235, 1.0000, -1.2000, 4.1000, NaN]\nTensor[[5], f64]" + ); + let t = (Tensor::ones(50, DType::F32, &Cpu)? * 42.)?; + let s = format!("\n{t}"); + let expected = r#" +[42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., + 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., + 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., 42., + 42., 42.] +Tensor[[50], f32]"#; + assert_eq!(&s, expected); + let t = (Tensor::ones(11000, DType::F32, &Cpu)? * 42.)?; + let s = format!("{t}"); + assert_eq!( + &s, + "[42., 42., 42., ..., 42., 42., 42.]\nTensor[[11000], f32]" + ); + Ok(()) +} + +#[test] +fn display_multi_dim() -> Result<()> { + let t = (Tensor::ones((200, 100), DType::F32, &Cpu)? * 42.)?; + let s = format!("\n{t}"); + let expected = r#" +[[42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + ... + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.]] +Tensor[[200, 100], f32]"#; + assert_eq!(&s, expected); + let t = t.reshape(&[2, 1, 1, 100, 100])?; + let t = format!("\n{t}"); + let expected = r#" +[[[[[42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + ... + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.]]]], + [[[[42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + ... 
+ [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.], + [42., 42., 42., ..., 42., 42., 42.]]]]] +Tensor[[2, 1, 1, 100, 100], f32]"#; + assert_eq!(&t, expected); + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/fortran_tensor_3d.pth b/mcandle/mcandle-core/tests/fortran_tensor_3d.pth new file mode 100644 index 000000000..bd50b03d9 Binary files /dev/null and b/mcandle/mcandle-core/tests/fortran_tensor_3d.pth differ diff --git a/mcandle/mcandle-core/tests/grad_tests.rs b/mcandle/mcandle-core/tests/grad_tests.rs new file mode 100644 index 000000000..6d5ba4912 --- /dev/null +++ b/mcandle/mcandle-core/tests/grad_tests.rs @@ -0,0 +1,533 @@ +#![allow(clippy::approx_constant)] +use anyhow::{Context, Result}; +use mcandle_core::{test_device, test_utils, Device, Shape, Tensor, Var}; + +fn simple_grad(device: &Device) -> Result<()> { + let x = Var::new(&[3f32, 1., 4.], device)?; + let x = x.as_tensor(); + let y = (((x * x)? + x * 5f64)? + 4f64)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(x.to_vec1::()?, [3., 1., 4.]); + // y = x^2 + 5.x + 4 + assert_eq!(y.to_vec1::()?, [28., 10., 40.]); + // dy/dx = 2.x + 5 + assert_eq!(grad_x.to_vec1::()?, [11., 7., 13.]); + Ok(()) +} + +fn sum_grad(device: &Device) -> Result<()> { + let x = Var::new(&[3f32, 1., 4.], device)?; + let x = x.as_tensor(); + let y = (x.sqr()?.sum_keepdim(0)? * 2.)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [52.]); + // y = 2.x^2 so dy/dx = 4.x + assert_eq!(grad_x.to_vec1::()?, &[12., 4., 16.]); + + // Same test as before but squeezing on the last dimension. + let y = (x.sqr()?.sum_keepdim(0)? * 2.)?.squeeze(0)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_scalar::()?, 52.); + // y = 2.x^2 so dy/dx = 4.x + assert_eq!(grad_x.to_vec1::()?, &[12., 4., 16.]); + Ok(()) +} + +fn matmul_grad(device: &Device) -> Result<()> { + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let x = Var::from_slice(&data, (2, 2, 3), device)?; + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let y = Var::from_slice(&data, (2, 3, 2), device)?; + let c = x.matmul(&y)?; + let grads = c.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + let grad_y = grads.get(&y).context("no grad for y")?; + assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3))); + assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2))); + assert_eq!( + &*grad_x.to_vec3::()?, + &[ + [[1., 5., 9.], [1., 5., 9.]], + [[13., 17., 21.], [13., 17., 21.]] + ] + ); + assert_eq!( + &*grad_y.to_vec3::()?, + &[ + [[3., 3.], [5., 5.], [7., 7.]], + [[15., 15.], [17., 17.], [19., 19.]] + ] + ); + Ok(()) +} + +// The simplest gradient descent, using scalar variable. +fn grad_descent(device: &Device) -> Result<()> { + let x = Var::new(0f32, device)?; + let learning_rate = 0.1; + for _step in 0..100 { + let xt = x.as_tensor(); + let c = ((xt - 4.2)? * (xt - 4.2)?)?; + let grads = c.backward()?; + let x_grad = grads.get(&x).context("no grad for x")?; + x.set(&(xt - x_grad * learning_rate)?)? + } + assert_eq!(x.to_scalar::()?, 4.199999); + Ok(()) +} + +fn unary_grad(device: &Device) -> Result<()> { + let x = Var::new(&[3f32, 1., 4., 0.15], device)?; + let x = x.as_tensor(); + let y = (x.log()? 
+ 1.)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [2.0986, 1.0, 2.3863, -0.8971] + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [0.3333, 1.0, 0.25, 6.6667] + ); + let y = x.exp()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [20.0855, 2.7183, 54.5982, 1.1618] + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [20.0855, 2.7183, 54.5982, 1.1618] + ); + let y = x.exp()?.sqr()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 3)?, + [403.429, 7.389, 2980.958, 1.35] + ); + // exp(x)^2 = exp(2*x) + assert_eq!( + test_utils::to_vec1_round(grad_x, 2)?, + [806.86, 14.78, 5961.92, 2.7] + ); + let y = x.sin()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [0.1411, 0.8415, -0.7568, 0.1494], + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [-0.99, 0.5403, -0.6536, 0.9888], + ); + let y = x.cos()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-0.99, 0.5403, -0.6536, 0.9888], + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [-0.1411, -0.8415, 0.7568, -0.1494], + ); + let y = x.sqr()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [9.0, 1.0, 16.0, 0.0225]); + assert_eq!(grad_x.to_vec1::()?, [6.0, 2.0, 8.0, 0.3]); + let y = x.sqr()?.sqrt()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [3.0, 1.0, 4.0, 0.15]); + assert_eq!(test_utils::to_vec1_round(grad_x, 4)?, [1.0, 1.0, 1.0, 1.0]); + let y = x.neg()?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [-3.0, -1.0, -4.0, -0.15]); + assert_eq!(grad_x.to_vec1::()?, [-1.0, -1.0, -1.0, -1.0]); + let y = x.affine(0.2, 1.)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [1.6, 1.2, 1.8, 1.03]); + assert_eq!(grad_x.to_vec1::()?, [0.2, 0.2, 0.2, 0.2]); + let y = Tensor::new(1f32, device)?.broadcast_div(x)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [0.3333, 1.0, 0.25, 6.6667] + ); + assert_eq!( + grad_x.to_vec1::()?, + [-0.11111111, -1.0, -0.0625, -44.444443], + ); + let y = x.broadcast_div(&Tensor::new(0.5f32, device)?)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [6., 2., 8., 0.3]); + assert_eq!(grad_x.to_vec1::()?, [2., 2., 2., 2.]); + + let x = Var::new(&[3f32, 1., 4., 0.15], device)?; + let y = x.powf(2.5)?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]); + assert_eq!( + test_utils::to_vec1_round(grad_x, 2)?, + [12.99, 2.5, 20.0, 0.15] + ); + + let y = x.tanh()?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]); + assert_eq!( + test_utils::to_vec1_round(grad_x, 2)?, + [0.01, 0.42, 0.0, 0.98], + ); + + // testing compared to 
pytorch nn.GELU(approximate = 'tanh') + let y = x.gelu()?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [2.9964, 0.8412, 3.9999, 0.0839] + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [1.0116, 1.0830, 1.0003, 0.6188], + ); + + // Testing compared to pytorch torch.erf + // + // import torch + // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True) + // y = x.erf() + // print(y) + // loss = y.sum() + // loss.backward() + // print(x.grad) + let y = x.erf()?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [0.0001, 0.4151, 0.0, 1.1033], + ); + + // Testing compared to pytorch nn.GELU(approximate = 'none') + // + // import torch + // import torch.nn.functional as F + // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True) + // y = F.gelu(x, approximate='none') + // print(y) + // loss = y.sum() + // loss.backward() + // print(x.grad) + let y = x.gelu_erf()?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [2.9960, 0.8413, 3.9999, 0.0839] + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [1.0119, 1.0833, 1.0005, 0.6188], + ); + + // Testing compared to pytorch elu + // + // import torch + // import torch.nn.functional as F + // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True) + // y = F.elu(x, alpha=2.0) + // print(y) + // loss = y.min + // loss = y.sum() + // loss.backward() + // print(x.grad) + let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?; + let y = elu_x.elu(2.)?; + let grads = y.backward()?; + let grad_x = grads.get(&elu_x).context("no grad for x")?; + + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, 0.0000, -1.7293, 3.0000] + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [0.7358, 2.0000, 0.2707, 1.0000] + ); + + // testing compared to pytorch nn.Silu() + let y = x.silu()?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [2.8577, 0.7311, 3.9281, 0.0806] + ); + assert_eq!( + test_utils::to_vec1_round(grad_x, 4)?, + [1.0881, 0.9277, 1.0527, 0.5747], + ); + + if device.is_cpu() { + let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?; + let y = x.interpolate1d(12)?.reshape(36)?; + + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., + 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 34., 35., 36., + ], + device, + )?; + + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + let grads = loss.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + + assert_eq!( + test_utils::to_vec3_round(grad_x, 4)?, + [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]] + ); + } + + // manually checked: see comments + let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?; + let y = x.interpolate2d(6, 6)?.reshape(36)?; + + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., + 35., 36., + ], + device, + )?; + // gradient should be + // row 1 + // 1+2+7+8 = 18 + 
// 3+4+9+10 = 26 + // 5+6+11+12 = 34 + // row 2 + // 13+14+19+20 = 66 + // 15+16+21+22 = 74 + // 17+18+23+24 = 82 + // row 3 + // 25+26+31+32 = 114 + // 27+28+33+34 = 122 + // 29+30+35+36 = 130 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?, + [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]] + ); + + // manually checked: see comments + let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?; + let y = x.interpolate2d(6, 6)?.reshape(36)?; + + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., + 35., 36., + ], + device, + )?; + // gradient should be + // row 1 + // 1+2+3+7+8+9+13+14+15 = 72 + // 4+5+6+10+11+12+16+17+18 = 99 + // row 2 + // 19+20+21+25+26+27+31+32+33 = 234 + // 22+23+24+28+29+30+34+35+36 = 243 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?, + [[72_f32, 99.], [234., 261.]] + ); + + // manually checked: see comments + let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?; + + let y = x.interpolate2d(4, 4)?.reshape(32)?; + + #[rustfmt::skip] + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., + 05., 06., 07., 08., + 09., 10., 11., 12., + 13., 14., 15., 16., + 17., 18., 19., 20., + 21., 22., 23., 24., + 25., 26., 27., 28., + 29., 30., 31., 32. + ], + device, + )?; + // gradient should be + // m1r1 + // 1+2+5+6=14 + // 3+4+7+8=22 + // m1r2 + // 9+10+13+14=46 + // 11+12+15+16=54 + // m2r1 + // 17+18+21+22=78 + // 19+20+23+24=86 + // m2r2 + // 25+26+29+30=110 + // 27+28+31+32=118 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + + assert_eq!( + test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?, + [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]] + ); + + // manually checked: see comments + let x = Var::new( + &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]], + device, + )?; + + let y = x.interpolate2d(4, 4)?.reshape(32)?; + + #[rustfmt::skip] + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., + 05., 06., 07., 08., + 09., 10., 11., 12., + 13., 14., 15., 16., + 17., 18., 19., 20., + 21., 22., 23., 24., + 25., 26., 27., 28., + 29., 30., 31., 32. 
+ ], + device, + )?; + // gradient should be + // m1r1 + // 1+2+5+6=14 + // 3+4+7+8=22 + // m1r2 + // 9+10+13+14=46 + // 11+12+15+16=54 + // m2r1 + // 17+18+21+22=78 + // 19+20+23+24=86 + // m2r2 + // 25+26+29+30=110 + // 27+28+31+32=118 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + + assert_eq!( + test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?, + [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]] + ); + Ok(()) +} + +fn binary_grad(device: &Device) -> Result<()> { + let x = Var::new(&[3f32, 1., -4., -1.], device)?; + let x = x.as_tensor(); + // leaky relu + let y = x.maximum(&(x * 0.1)?)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(x.to_vec1::()?, [3., 1., -4., -1.]); + assert_eq!(y.to_vec1::()?, [3., 1., -0.4, -0.1]); + assert_eq!(grad_x.to_vec1::()?, [1., 1., 0.1, 0.1]); + + let y = x.minimum(&(x * 0.1)?)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [0.3, 0.1, -4., -1.]); + assert_eq!(grad_x.to_vec1::()?, [0.1, 0.1, 1., 1.]); + + // This one is easy to mess up, we want the gradient to be one as it is the identity function. + let y = x.minimum(x)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::()?, [3., 1., -4., -1.]); + assert_eq!(grad_x.to_vec1::()?, [1., 1., 1., 1.]); + + let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?; + let x = x_var.as_tensor(); + let y_var = Var::new(&[2f32, 7., 1.], device)?; + let y = y_var.as_tensor(); + + let ss = x + .reshape((2, 3))? + .slice_scatter0(&y.reshape((1, 3))?, 1)? + .sqr()?; + let grads = ss.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + let grad_y = grads.get(y).context("no grad for y")?; + assert_eq!(ss.to_vec2::()?, [[9., 1., 16.], [4., 49., 1.]]); + assert_eq!(grad_x.to_vec1::()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]); + assert_eq!(grad_y.to_vec1::()?, [4.0, 14.0, 2.0]); + Ok(()) +} + +test_device!( + simple_grad, + simple_grad_cpu, + simple_grad_gpu, + simple_grad_metal +); +test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu, sum_grad_metal); +test_device!( + matmul_grad, + matmul_grad_cpu, + matmul_grad_gpu, + matmul_grad_metal +); +test_device!( + grad_descent, + grad_descent_cpu, + grad_descent_gpu, + grad_descent_metal +); +test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu, unary_grad_metal); +test_device!( + binary_grad, + binary_grad_cpu, + binary_grad_gpu, + binary_grad_metal +); diff --git a/mcandle/mcandle-core/tests/indexing_tests.rs b/mcandle/mcandle-core/tests/indexing_tests.rs new file mode 100644 index 000000000..0687ff4c4 --- /dev/null +++ b/mcandle/mcandle-core/tests/indexing_tests.rs @@ -0,0 +1,217 @@ +use anyhow::Result; +use mcandle_core::{Device, IndexOp, Tensor}; + +#[test] +fn integer_index() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((2, 3))?; + let result = tensor.i(1)?; + assert_eq!(result.dims(), &[3]); + assert_eq!(result.to_vec1::()?, &[3, 4, 5]); + + let result = tensor.i((.., 2))?; + assert_eq!(result.dims(), &[2]); + assert_eq!(result.to_vec1::()?, &[2, 5]); + + Ok(()) +} + +#[test] +fn range_index() -> Result<()> { + let dev = Device::Cpu; + // RangeFull + let tensor = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((2, 3))?; + let result = tensor.i(..)?; + assert_eq!(result.dims(), &[2, 3]); + 
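Aside on the gradient tests above: they assert hand-derived analytic values, and when adding a new op it can help to cross-check `backward()` against a central finite-difference estimate. The sketch below assumes the same `Var`/`backward`/`to_vec1` API exercised in those tests; `finite_diff_check`, its epsilon, and its tolerance are illustrative choices rather than part of the suite, and it would still need to be wired up through `test_device!` like the other helpers to run.

use anyhow::{Context, Result};
use mcandle_core::{Device, Tensor, Var};

fn finite_diff_check(device: &Device) -> Result<()> {
    let eps = 1e-2f32;
    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
    // Analytic gradient of sum(exp(x)) via backprop.
    let y = x.exp()?.sum_all()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?.to_vec1::<f32>()?;
    // Central finite differences, one coordinate at a time.
    let x0 = x.to_vec1::<f32>()?;
    let f = |v: &[f32]| -> Result<f32> {
        Ok(Tensor::from_slice(v, (v.len(),), device)?
            .exp()?
            .sum_all()?
            .to_vec0::<f32>()?)
    };
    for i in 0..x0.len() {
        let (mut lo, mut hi) = (x0.clone(), x0.clone());
        lo[i] -= eps;
        hi[i] += eps;
        let numeric = (f(hi.as_slice())? - f(lo.as_slice())?) / (2. * eps);
        assert!((numeric - grad_x[i]).abs() / grad_x[i].abs() < 1e-2);
    }
    Ok(())
}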
assert_eq!(result.to_vec2::()?, &[[0, 1, 2], [3, 4, 5]]); + + // Range + let tensor = Tensor::arange(0u32, 4 * 3, &dev)?.reshape((4, 3))?; + let result = tensor.i(1..3)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[3, 4, 5], [6, 7, 8]]); + + // RangeFrom + let result = tensor.i(2..)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[6, 7, 8], [9, 10, 11]]); + + // RangeTo + let result = tensor.i(..2)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2], [3, 4, 5]]); + + // RangeInclusive + let result = tensor.i(1..=2)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[3, 4, 5], [6, 7, 8]]); + + // RangeTo + let result = tensor.i(..1)?; + assert_eq!(result.dims(), &[1, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2]]); + + // RangeToInclusive + let result = tensor.i(..=1)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2], [3, 4, 5]]); + + // Empty range + let result = tensor.i(1..1)?; + assert_eq!(result.dims(), &[0, 3]); + let empty: [[u32; 3]; 0] = []; + assert_eq!(result.to_vec2::()?, &empty); + + // Similar to PyTorch, allow empty ranges when the computed length is negative. + #[allow(clippy::reversed_empty_ranges)] + let result = tensor.i(1..0)?; + assert_eq!(result.dims(), &[0, 3]); + let empty: [[u32; 3]; 0] = []; + assert_eq!(result.to_vec2::()?, &empty); + Ok(()) +} + +#[test] +fn index_3d() -> Result<()> { + let tensor = Tensor::from_iter(0..24u32, &Device::Cpu)?.reshape((2, 3, 4))?; + assert_eq!(tensor.i((0, 0, 0))?.to_scalar::()?, 0); + assert_eq!(tensor.i((1, 0, 0))?.to_scalar::()?, 12); + assert_eq!(tensor.i((0, 1, 0))?.to_scalar::()?, 4); + assert_eq!(tensor.i((0, 1, 3))?.to_scalar::()?, 7); + assert_eq!(tensor.i((0..2, 0, 0))?.to_vec1::()?, &[0, 12]); + assert_eq!( + tensor.i((0..2, .., 0))?.to_vec2::()?, + &[[0, 4, 8], [12, 16, 20]] + ); + assert_eq!( + tensor.i((..2, .., 3))?.to_vec2::()?, + &[[3, 7, 11], [15, 19, 23]] + ); + assert_eq!(tensor.i((1, .., 3))?.to_vec1::()?, &[15, 19, 23]); + Ok(()) +} + +#[test] +fn slice_assign_range() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(1..4), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 3, 4], + [5, 6, 7, 100, 101], + [10, 11, 12, 102, 103], + [15, 16, 17, 104, 105] + ] + ); + let out = tensor.slice_assign(&[&(0..3), &(0..2)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [100, 101, 2, 3, 4], + [102, 103, 7, 8, 9], + [104, 105, 12, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_to() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(..3), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 102, 103], + [10, 11, 12, 104, 105], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_from() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(1..), &(0..2)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 3, 4], + [100, 101, 7, 8, 9], + [102, 103, 
12, 13, 14], + [104, 105, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_to_incl() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 3) + 100, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[&(..=2), &(1..3)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 100, 101, 3, 4], + [5, 102, 103, 8, 9], + [10, 104, 105, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_full() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, (2 * 4) + 100, &dev)?.reshape((4, 2))?; + let out = tensor.slice_assign(&[&(..), &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 102, 103], + [10, 11, 12, 104, 105], + [15, 16, 17, 106, 107] + ] + ); + Ok(()) +} + +#[test] +fn slice_assign_exact() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(100u32, 2 + 100, &dev)?.reshape((1, 2))?; + let out = tensor.slice_assign(&[&0, &(3..5)], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 100, 101], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/layout_tests.rs b/mcandle/mcandle-core/tests/layout_tests.rs new file mode 100644 index 000000000..bb3e6231b --- /dev/null +++ b/mcandle/mcandle-core/tests/layout_tests.rs @@ -0,0 +1,151 @@ +use mcandle_core::{test_device, Device, IndexOp, Result, Tensor}; + +fn contiguous(device: &Device) -> Result<()> { + let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?; + assert_eq!( + tensor.to_vec3::()?, + &[ + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ] + ); + assert_eq!( + tensor.t()?.contiguous()?.to_vec3::()?, + &[ + [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], + [[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]] + ] + ); + assert_eq!( + tensor.transpose(0, 1)?.contiguous()?.to_vec3::()?, + &[ + [[0, 1, 2, 3], [12, 13, 14, 15]], + [[4, 5, 6, 7], [16, 17, 18, 19]], + [[8, 9, 10, 11], [20, 21, 22, 23]] + ] + ); + assert_eq!( + tensor.transpose(0, 1)?.flatten_all()?.to_vec1::()?, + &[0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23] + ); + assert_eq!( + tensor + .i(1..)? + .transpose(0, 1)? + .contiguous()? + .to_vec3::()?, + &[[[12, 13, 14, 15]], [[16, 17, 18, 19]], [[20, 21, 22, 23]]] + ); + assert_eq!( + tensor.transpose(0, 2)?.contiguous()?.to_vec3::()?, + &[ + [[0, 12], [4, 16], [8, 20]], + [[1, 13], [5, 17], [9, 21]], + [[2, 14], [6, 18], [10, 22]], + [[3, 15], [7, 19], [11, 23]] + ] + ); + Ok(()) +} + +test_device!(contiguous, contiguous_cpu, contiguous_gpu, contiguous_metal); + +#[test] +fn strided_blocks() -> Result<()> { + use mcandle_core::Device::Cpu; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + match tensor.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { start_offset, len } => { + assert_eq!(start_offset, 0); + assert_eq!(len, 24); + } + mcandle_core::StridedBlocks::MultipleBlocks { .. } => { + panic!("unexpected block structure") + } + }; + let tensor = Tensor::arange(0u32, 26u32, &Cpu)? + .i(2..)? 
+ .reshape((2, 3, 4))?; + match tensor.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { start_offset, len } => { + assert_eq!(start_offset, 2); + assert_eq!(len, 24); + } + mcandle_core::StridedBlocks::MultipleBlocks { .. } => { + panic!("unexpected block structure") + } + }; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + let tensor = tensor.i(1)?; + match tensor.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { start_offset, len } => { + assert_eq!(start_offset, 12); + assert_eq!(len, 12); + } + mcandle_core::StridedBlocks::MultipleBlocks { .. } => { + panic!("unexpected block structure") + } + }; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + let tensor = tensor.i((.., 1))?.contiguous()?; + match tensor.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { start_offset, len } => { + assert_eq!(start_offset, 0); + assert_eq!(len, 8); + assert_eq!(tensor.to_vec2::()?, &[[4, 5, 6, 7], [16, 17, 18, 19]]); + } + mcandle_core::StridedBlocks::MultipleBlocks { .. } => { + panic!("unexpected block structure") + } + }; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + let tensor = tensor.i((.., 1))?; + match tensor.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { .. } => { + panic!("unexpected block structure") + } + mcandle_core::StridedBlocks::MultipleBlocks { + block_len, + block_start_index, + } => { + assert_eq!(block_len, 4); + assert_eq!(block_start_index.collect::>(), &[4, 16]) + } + }; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + match tensor.t()?.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { .. } => { + panic!("unexpected block structure") + } + mcandle_core::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + assert_eq!(block_len, 1); + assert_eq!( + block_start_index.collect::>(), + &[ + 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, + 19, 23 + ] + ) + } + }; + let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?; + match tensor.transpose(0, 1)?.strided_blocks() { + mcandle_core::StridedBlocks::SingleBlock { .. 
} => { + panic!("unexpected block structure") + } + mcandle_core::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + assert_eq!(block_len, 4); + assert_eq!( + block_start_index.collect::>(), + &[0, 12, 4, 16, 8, 20] + ) + } + }; + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/matmul_tests.rs b/mcandle/mcandle-core/tests/matmul_tests.rs new file mode 100644 index 000000000..7f2d9b873 --- /dev/null +++ b/mcandle/mcandle-core/tests/matmul_tests.rs @@ -0,0 +1,172 @@ +use mcandle_core::{test_device, DType, Device, IndexOp, Result, Tensor}; + +fn matmul(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + + let data = vec![1.0f32, 2.0]; + let a = Tensor::from_slice(&data, (2, 1), device)?; + let data = vec![3.0f32, 4.0]; + let b = Tensor::from_slice(&data, (1, 2), device)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); + + let data: Vec<_> = (0..6).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 3), device)?; + let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (3, 2), device)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); + + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 2, 3), device)?; + let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (2, 3, 2), device)?; + let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]]; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec3::()?, &expected); + + // Also perform the matmul on contiguous transposed versions. + let a_tt = a.t()?.contiguous()?.t()?; + assert!(!a_tt.is_contiguous()); + assert_eq!(a.dims(), a_tt.dims()); + assert_eq!(a_tt.stride(), &[6, 1, 2]); + + let b_tt = b.t()?.contiguous()?.t()?; + assert!(!b_tt.is_contiguous()); + assert_eq!(b.dims(), b_tt.dims()); + assert_eq!(b_tt.stride(), &[6, 1, 3]); + + assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected); + assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected); + assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::()?, &expected); + Ok(()) +} + +fn matmul_bf16(device: &Device) -> Result<()> { + if !device.supports_bf16() { + return Ok(()); + } + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; + + let c = a.matmul(&b)?.to_dtype(DType::F32)?; + assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + Ok(()) +} + +fn broadcast_matmul(device: &Device) -> Result<()> { + let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?; + let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?; + let out = lhs.broadcast_matmul(&rhs)?; + assert_eq!(out.dims(), &[3, 6, 4, 2]); + for idx1 in 0..3 { + for idx2 in 0..6 { + let out = out.i((idx1, idx2))?; + let lhs = lhs.i((idx1, 0))?; + let rhs = rhs.i(idx2)?; + let out2 = lhs.matmul(&rhs); + let sum_diff2 = (out - out2)?.sqr()?.sum_all()?; + // With cuda, we see errors of up to ~1e-12. + assert!(sum_diff2.to_vec0::()? 
< 1e-6) + } + } + Ok(()) +} + +// https://github.com/huggingface/mcandle/issues/1948 +fn squeeze_mm(device: &Device) -> Result<()> { + let seq_len = 8_usize; + let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?; + let x = a.i((.., seq_len - 1, ..))?; + let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?; + let x = x.matmul(&w)?; + assert_eq!(x.dims(), &[1, 32]); + Ok(()) +} + +// https://github.com/huggingface/mcandle/issues/1992 +fn mm_layout(device: &Device) -> Result<()> { + let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?; + let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?; + let mm1 = a.matmul(&b)?; + // Forces the layout to be: + // shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0 + // This is still a contiguous matrix but matmul checks are only the two last dimensions have + // non 1 sizes but matmul check may be reluctant to handle it. + let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?; + let mm2 = a.matmul(&b)?; + let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0.); + Ok(()) +} + +fn matmul_alpha_beta(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 1.0, 1.0, 1.0]; + let mut c = Tensor::from_slice(&data, (2, 2), device)?; + + a.matmul_with_alpha_beta(&b, &mut c, None)?; + assert_eq!(c.to_vec2::()?, &[[8.0f32, 11.0], [16.0, 23.0]]); + + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 1.0, 1.0, 1.0]; + let mut c = Tensor::from_slice(&data, (2, 2), device)?; + + a.matmul_with_alpha_beta(&b, &mut c, Some(2.))?; + assert_eq!(c.to_vec2::()?, &[[15.0f32, 21.0], [31.0, 45.0]]); + Ok(()) +} + +fn matmul_alpha(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + + let c = a.matmul_with_alpha(&b, Some(2.))?; + assert_eq!(c.to_vec2::()?, &[[14.0f32, 20.0], [30.0, 44.0]]); + Ok(()) +} + +test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); +test_device!( + matmul_alpha_beta, + matmul_alpha_beta_cpu, + matmul_alpha_beta_gpu, + matmul_alpha_beta_metal +); +test_device!( + matmul_alpha, + matmul_alpha_cpu, + matmul_alpha_gpu, + matmul_alpha_metal +); +test_device!( + matmul_bf16, + matmul_bf16_cpu, + matmul_bf16_gpu, + matmul_bf16_metal +); +test_device!( + broadcast_matmul, + broadcast_matmul_cpu, + broadcast_matmul_gpu, + broadcast_matmul_metal +); +test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal); +test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal); diff --git a/mcandle/mcandle-core/tests/npy.py b/mcandle/mcandle-core/tests/npy.py new file mode 100644 index 000000000..0fd2778a1 --- /dev/null +++ b/mcandle/mcandle-core/tests/npy.py @@ -0,0 +1,9 @@ +import numpy as np +x = np.arange(10) + +# Write a npy file. +np.save("test.npy", x) + +# Write multiple values to a npz file. 
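As a reading aid for the `matmul_with_alpha_beta` / `matmul_with_alpha` expectations above: the asserted values are consistent with the update c <- alpha * (a @ b) + c, with alpha treated as 1 when `None` is passed, and with no accumulator in the plain `matmul_with_alpha` case. This is a plain-Rust back-check of that arithmetic only; it uses no mcandle API and merely restates the convention implied by the expected outputs.

// Reference GEMM with the alpha convention implied by the asserts above.
fn gemm_alpha(a: &[[f32; 2]; 2], b: &[[f32; 2]; 2], c: &mut [[f32; 2]; 2], alpha: f32) {
    for i in 0..2 {
        for j in 0..2 {
            let mut acc = 0f32;
            for k in 0..2 {
                acc += a[i][k] * b[k][j];
            }
            c[i][j] += alpha * acc;
        }
    }
}

fn main() {
    let a = [[1f32, 2.], [3., 4.]];
    // None case: c <- 1.0 * (a @ a) + c.
    let mut c = [[1f32; 2]; 2];
    gemm_alpha(&a, &a, &mut c, 1.0);
    assert_eq!(c, [[8., 11.], [16., 23.]]);
    // Some(2.) case: c <- 2.0 * (a @ a) + c.
    let mut c = [[1f32; 2]; 2];
    gemm_alpha(&a, &a, &mut c, 2.0);
    assert_eq!(c, [[15., 21.], [31., 45.]]);
    // matmul_with_alpha (no accumulator): 2.0 * (a @ a).
    let mut c = [[0f32; 2]; 2];
    gemm_alpha(&a, &a, &mut c, 2.0);
    assert_eq!(c, [[14., 20.], [30., 44.]]);
}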
+values = { "x": x, "x_plus_one": x + 1 } +np.savez("test.npz", **values) diff --git a/mcandle/mcandle-core/tests/pool_tests.rs b/mcandle/mcandle-core/tests/pool_tests.rs new file mode 100644 index 000000000..a3d6d6a3d --- /dev/null +++ b/mcandle/mcandle-core/tests/pool_tests.rs @@ -0,0 +1,117 @@ +use mcandle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor}; + +// https://github.com/huggingface/mcandle/issues/364 +fn avg_pool2d(dev: &Device) -> Result<()> { + let data: Vec = vec![ + 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?; + let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::()?, [[0.5f32, 1.], [1., 1.]]); + + let data: Vec = vec![ + 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., + ]; + let t = Tensor::from_vec(data, (1, 1, 2, 8), dev)?; + let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::()?, [[5. / 4., 6. / 4., 6. / 4., 1.]]); + Ok(()) +} + +fn max_pool2d(dev: &Device) -> Result<()> { + let data: Vec = vec![ + 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., + ]; + let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?; + + let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::()?, [[2f32, 3.], [5., 1.]]); + + let t = t.reshape((1, 1, 2, 8))?; + let pool = t.max_pool2d(2)?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::()?, [[2.0, 3.0, 5.0, 1.0]]); + Ok(()) +} + +/* This test corresponds to the following PyTorch script. +import torch +torch.manual_seed(4242) + +t = torch.randn((1, 2, 4, 4)) +print(t.flatten()) +res = torch.nn.functional.avg_pool2d(t, 2) +print(res) +*/ +fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { + if dev.is_metal() { + return Ok(()); + } + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, + 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395, + 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836, + 0.2477, 1.3127, + ], + dev, + )? 
+ .reshape((1, 2, 4, 4))?; + let pool = t.avg_pool2d(2)?.squeeze(0)?; + assert_eq!( + test_utils::to_vec3_round(&pool, 4)?, + [ + [[-1.1926, -0.0395], [0.2688, 0.1871]], + [[0.1835, -0.1606], [0.6249, 0.3217]] + ] + ); + let pool = t.avg_pool2d(3)?.squeeze(0)?; + assert_eq!( + test_utils::to_vec3_round(&pool, 4)?, + [[[0.085]], [[0.0078]]] + ); + + let t = t.reshape((1, 1, 4, 8))?; + let pool = t.avg_pool2d(2)?.squeeze(0)?.squeeze(0)?; + assert_eq!( + test_utils::to_vec2_round(&pool, 4)?, + [ + [0.7745, 0.0276, -1.6983, 0.12], + [0.3542, 0.1625, 0.4542, -0.0014] + ] + ); + Ok(()) +} + +fn upsample_nearest2d(dev: &Device) -> Result<()> { + let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?; + let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?; + assert_eq!( + t.i(0)?.i(0)?.to_vec2::()?, + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + ); + assert_eq!( + upsampled.to_vec2::()?, + [ + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], + [3.0, 3.0, 4.0, 4.0, 5.0, 5.0], + [3.0, 3.0, 4.0, 4.0, 5.0, 5.0] + ] + ); + Ok(()) +} + +test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu, avg_pool2d_metal); +test_device!( + avg_pool2d_pytorch, + avg_pool2d_pytorch_cpu, + avg_pool2d_pytorch_gpu, + avg_pool2d_pytorch_metal +); +test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu, max_pool2d_metal); +test_device!( + upsample_nearest2d, + upsample_nearest2d_cpu, + upsample_nearest2d_gpu, + upsample_nearest2d_metal +); diff --git a/mcandle/mcandle-core/tests/pth.py b/mcandle/mcandle-core/tests/pth.py new file mode 100644 index 000000000..5c787c203 --- /dev/null +++ b/mcandle/mcandle-core/tests/pth.py @@ -0,0 +1,37 @@ +import torch +from collections import OrderedDict + +# Write a trivial tensor to a pt file +a= torch.tensor([[1,2,3,4], [5,6,7,8]]) +o = OrderedDict() +o["test"] = a + +# Write a trivial tensor to a pt file +torch.save(o, "test.pt") + +############################################################################################################ +# Write a trivial tensor to a pt file with a key +torch.save({"model_state_dict": o}, "test_with_key.pt") + +############################################################################################################ +# Create a tensor with fortran contiguous memory layout +import numpy as np + +# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers +# For example, creating a 2x3x4 array +array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4)) + +# Verify the memory order +print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True +print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False + +# Step 2: Convert the NumPy array to a PyTorch tensor +tensor_fortran = torch.from_numpy(array_fortran) + +# Verify the tensor layout +print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout + +# Step 3: Save the PyTorch tensor to a .pth file +torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth') + +print("3D Tensor saved with Fortran layout.") diff --git a/mcandle/mcandle-core/tests/pth_tests.rs b/mcandle/mcandle-core/tests/pth_tests.rs new file mode 100644 index 000000000..d3c458599 --- /dev/null +++ b/mcandle/mcandle-core/tests/pth_tests.rs @@ -0,0 +1,31 @@ +/// Regression test for pth files not loading on Windows. 
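Before the pth tests: pth.py above deliberately saves a Fortran-ordered (2, 3, 4) tensor so that the loader has to honor the stored strides; for that array, element (i, j, k) sits at element offset i + 2*j + 6*k. The standalone sketch below performs the same stride-aware gather. The offsets are derived from the numpy layout shown in pth.py, not from the crate's internals, so treat it as an illustration of what the regression test checks rather than how the loader is implemented.

fn main() {
    // Memory as numpy lays out arange(1, 25).reshape(2, 3, 4) in Fortran order:
    // the value 1 + 12*i + 4*j + k is stored at element offset i + 2*j + 6*k.
    let buf: Vec<i64> = (0..24usize)
        .map(|m| {
            let (i, j, k) = (m % 2, (m / 2) % 3, m / 6);
            (1 + 12 * i + 4 * j + k) as i64
        })
        .collect();
    let (si, sj, sk) = (1usize, 2usize, 6usize); // strides in elements
    // Stride-aware gather back into row-major (C) order.
    let mut c_order = vec![0i64; 24];
    for i in 0..2 {
        for j in 0..3 {
            for k in 0..4 {
                c_order[i * 12 + j * 4 + k] = buf[i * si + j * sj + k * sk];
            }
        }
    }
    // First rows of the two outer slices, matching the values asserted below.
    assert_eq!(c_order[0..4].to_vec(), vec![1, 2, 3, 4]);
    assert_eq!(c_order[12..16].to_vec(), vec![13, 14, 15, 16]);
}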
+#[test]
+fn test_pth() {
+    let tensors = mcandle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
+    tensors.get("test").unwrap().unwrap();
+}
+
+#[test]
+fn test_pth_with_key() {
+    let tensors =
+        mcandle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
+            .unwrap();
+    tensors.get("test").unwrap().unwrap();
+}
+
+#[test]
+fn test_pth_fortran_contiguous() {
+    let tensors =
+        mcandle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
+    let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
+
+    assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
+
+    assert_eq!(
+        tensor.to_vec3::<i64>().unwrap(),
+        [
+            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
+            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
+        ]
+    );
+}
diff --git a/mcandle/mcandle-core/tests/quantized_tests.rs b/mcandle/mcandle-core/tests/quantized_tests.rs
new file mode 100644
index 000000000..5cf83efc8
--- /dev/null
+++ b/mcandle/mcandle-core/tests/quantized_tests.rs
@@ -0,0 +1,1187 @@
+use mcandle_core::{
+    bail,
+    quantized::{self, GgmlDType},
+    test_device,
+    test_utils::to_vec2_round,
+    DType, Device, IndexOp, Module, Result, Tensor,
+};
+use quantized::{k_quants, GgmlType};
+use rand::prelude::*;
+
+const GGML_TEST_SIZE: usize = 32 * 128;
+
+const GGML_MAX_QUANTIZATION_TOTAL_ERROR: f32 = 0.002;
+const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
+const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
+const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
+
+fn test_matmul(
+    device: &Device,
+    (b, m, n, k): (usize, usize, usize, usize),
+    dtype: GgmlDType,
+) -> Result<()> {
+    let lhs = (0..(m * k))
+        .map(|v| v as f32 / (m * k) as f32)
+        .collect::<Vec<_>>();
+    let rhs = (0..(k * n))
+        .map(|v| v as f32 / (n * k) as f32)
+        .collect::<Vec<_>>();
+
+    let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
+    let mm = lhs.matmul(&rhs)?;
+    let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
+    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
+    let res = matmul.forward(&lhs)?;
+
+    let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
+        .sum_all()?
+        .to_scalar()?;
+    let error = error / (b * m * n) as f32;
+    assert!(
+        error <= 0.02,
+        "Error {error} is too big.
\nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}" + ); + + Ok(()) +} + +fn quantized_matmul(device: &Device) -> Result<()> { + let (m, k, n) = (3, 64, 4); + let lhs_s = (0..(m * k)).map(|v| v as f32).collect::>(); + let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?; + let mut dst = vec![42.; 3 * 4]; + let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; + let rhs = (0..(k * n)).map(|v| v as f32).collect::>(); + k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; + k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; + assert_eq!( + dst.iter().map(|x| x.round()).collect::>(), + &[ + 85120.0, 214562.0, 345455.0, 474748.0, 213475.0, 604465.0, 1000686.0, 1388317.0, + 341876.0, 994283.0, 1655709.0, 2301518.0 + ] + ); + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; + let mm = lhs.matmul(&tensor_rhs)?; + assert_eq!( + mm.to_vec2::()?, + &[ + [85344.0, 214368.0, 343392.0, 472416.0], + [214368.0, 605536.0, 996704.0, 1387872.0], + [343392.0, 996704.0, 1650016.0, 2303328.0] + ] + ); + + let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; + let matmul = quantized::QMatMul::from_qtensor(qtensor)?; + let res = matmul.forward(&lhs)?; + match device { + Device::Metal(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [84946.0, 214126.0, 344757.0, 473798.0], + [213458.0, 604350.0, 1000469.0, 1387990.0], + [341970.0, 994574.0, 1656181.0, 2302182.0] + ] + ), + Device::Cuda(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [84866.0, 214045.0, 344676.0, 473707.0], + [213425.0, 604313.0, 1000431.0, 1387960.0], + [342030.0, 994630.0, 1656248.0, 2302250.0] + ] + ), + Device::Cpu => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [85120.0, 214562.0, 345455.0, 474748.0], + [213475.0, 604465.0, 1000686.0, 1388317.0], + [341876.0, 994283.0, 1655709.0, 2301518.0] + ] + ), + } + test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?; + Ok(()) +} + +fn quantized_matmul_neg(device: &Device) -> Result<()> { + let (m, k, n) = (3, 64, 4); + let lhs_s = (0..(m * k)) + .map(|v| v as f32 - (m * k) as f32 / 2.0) + .collect::>(); + let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?; + let mut dst = vec![42.; 3 * 4]; + let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; + let rhs = (0..k * n) + .map(|v| v as f32 - (k * n) as f32 / 3.0) + .collect::>(); + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?; + k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; + k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?; + assert_eq!( + dst.iter().map(|x| x.round()).collect::>(), + &[ + 243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0, + -196472.0, 63012.0, 324585.0, 587902.0 + ] + ); + let mm = lhs.matmul(&tensor_rhs)?; + assert_eq!( + to_vec2_round(&mm, 0)?, + &[ + [244064.0, -20128.0, -284320.0, -548512.0], + [23563.0, 21515.0, 19467.0, 17419.0], + [-196939.0, 63157.0, 323253.0, 583349.0] + ] + ); + + let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?; + let matmul = quantized::QMatMul::from_qtensor(qtensor)?; + let res = matmul.forward(&lhs)?; + match device { + Device::Metal(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243666.0, -19714.0, -285433.0, -550453.0], + [23782.0, 21654.0, 19400.0, 18369.0], + [-196102.0, 63022.0, 324233.0, 587191.0] + ] + ), + Device::Cuda(_) => assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243740.0, -19762.0, -285476.0, -550498.0], + [23774.0, 21645.0, 19395.0, 18364.0], + [-196045.0, 63030.0, 324120.0, 587079.0] + ] + ), + Device::Cpu => assert_eq!( + 
to_vec2_round(&res, 0)?, + &[ + [243524.0, -19596.0, -285051.0, -549815.0], + [23777.0, 21651.0, 19398.0, 18367.0], + [-196472.0, 63012.0, 324585.0, 587902.0] + ] + ), + } + let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?; + let res2 = matmul.forward(&lhs2)?; + let res2 = res2.i(1)?; + let diff = (&res - res2)?.abs()?.mean_all()?.to_vec0::()? / res.elem_count() as f32; + if device.is_cuda() { + assert!(diff < 0.1); + } else { + assert!(diff < 0.96); + } + Ok(()) +} + +fn qmm_batch(dev: &Device) -> Result<()> { + let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?; + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + assert_eq!(mm.shape().dims(), [2, 6]); + let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?; + let mm2 = rhs.forward(&lhs2)?; + assert_eq!(mm2.shape().dims(), [4, 6]); + let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff2, 0.0); + let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?; + let mm3 = rhs.forward(&lhs3)?; + assert_eq!(mm3.shape().dims(), [6, 6]); + let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff3, 0.0); + let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff3, 0.0); + let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?; + let mm4 = rhs.forward(&lhs4)?; + assert_eq!(mm4.shape().dims(), [12, 6]); + let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::()?; + if dev.is_cuda() { + // We use a different kernel for sizes from 1 to 8 on cuda which explains + // the difference here. + assert!(0. < diff4 && diff4 < 1e-4) + } else { + assert_eq!(diff4, 0.0) + }; + let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff4, 0.0); + Ok(()) +} + +test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal); +test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal); +test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal); + +fn quantize_q4_0(device: &Device) -> Result<()> { + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + assert_eq!( + dst.to_vec1::()?, + &[ + -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625, + 11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25, + 23.25, 27.125, 27.125, 27.125, 27.125, 31.0, 31.0, 31.5, 31.5, 31.5, 31.5, 39.375, + 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 39.375, 47.25, 47.25, 47.25, 47.25, + 47.25, 47.25, 47.25, 47.25, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, 55.125, + 55.125, 63.0, 63.0, 63.0, 63.0, 59.375, 59.375, 71.25, 71.25, 71.25, 71.25, 71.25, + 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 83.125, 83.125, 83.125, 83.125, + 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 95.0, 95.0, 95.0, 95.0, + 95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 95.25, 111.125, 111.125, + 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, 111.125, + 111.125, 111.125, 111.125, 111.125, 111.125, 127.0, 127.0, 127.0, 127.0, 127.0, 127.0, + 127.0, 127.0 + ] + ); + ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn quantize_q4_1(device: &Device) -> Result<()> { + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + assert_eq!( + round_vector(&dst.to_vec1::()?), + &[ + 0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332, + 12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73, + 22.73, 24.797, 24.797, 26.863, 26.863, 28.93, 28.93, 30.996, 30.996, 32.0, 32.0, + 34.066, 34.066, 36.133, 36.133, 38.199, 38.199, 40.266, 40.266, 42.332, 42.332, 44.398, + 44.398, 46.465, 46.465, 48.531, 48.531, 50.598, 50.598, 52.664, 52.664, 54.73, 54.73, + 56.797, 56.797, 58.863, 58.863, 60.93, 60.93, 62.996, 62.996, 64.0, 64.0, 66.066, + 66.066, 68.133, 68.133, 70.199, 70.199, 72.266, 72.266, 74.332, 74.332, 76.398, 76.398, + 78.465, 78.465, 80.531, 80.531, 82.598, 82.598, 84.664, 84.664, 86.73, 86.73, 88.797, + 88.797, 90.863, 90.863, 92.93, 92.93, 94.996, 94.996, 96.0, 96.0, 98.066, 98.066, + 100.133, 100.133, 102.199, 102.199, 104.266, 104.266, 106.332, 106.332, 108.398, + 108.398, 110.465, 110.465, 112.531, 112.531, 114.598, 114.598, 116.664, 116.664, + 118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996 + ] + ); + ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn quantize_q5_0(device: &Device) -> Result<()> { + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + assert_eq!( + round_vector(&dst.to_vec1::()?), + &[ + -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625, + 11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313, + 23.25, 23.25, 25.188, 25.188, 27.125, 27.125, 29.063, 29.063, 31.0, 31.5, 31.5, 35.438, + 35.438, 35.438, 35.438, 39.375, 39.375, 39.375, 39.375, 43.313, 43.313, 43.313, 43.313, + 47.25, 47.25, 47.25, 47.25, 51.188, 51.188, 51.188, 51.188, 55.125, 55.125, 55.125, + 55.125, 59.063, 59.063, 59.063, 59.063, 63.0, 63.0, 65.313, 65.313, 65.313, 65.313, + 65.313, 71.25, 71.25, 71.25, 71.25, 71.25, 71.25, 77.188, 77.188, 77.188, 77.188, + 77.188, 77.188, 83.125, 83.125, 83.125, 83.125, 83.125, 83.125, 89.063, 89.063, 89.063, + 89.063, 89.063, 89.063, 95.0, 95.0, 95.0, 95.25, 95.25, 95.25, 95.25, 103.188, 103.188, + 103.188, 103.188, 103.188, 103.188, 103.188, 103.188, 111.125, 111.125, 111.125, + 111.125, 111.125, 111.125, 111.125, 111.125, 119.063, 119.063, 119.063, 119.063, + 119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0 + ] + ); + ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn quantize_q5_1(device: &Device) -> Result<()> { + let src = (0..32 * 4).map(|v| v as f32).collect::>(); + let src = Tensor::from_slice(&src, (32 * 4,), device)?; + let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + assert_eq!( + round_vector(&dst.to_vec1::()?), + &[ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, + 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, + 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, + 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, + 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, + 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, + 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, + 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, + 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, + 124.0, 125.0, 126.0, 127.0 + ] + ); + ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result { + assert!( + size % crate::quantized::k_quants::QK_K == 0, + "size must be a multiple of {}", + crate::quantized::k_quants::QK_K + ); + + let src = (0..size) + .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.)) + .collect::>(); + assert_eq!([src[0], src[size / 2]], [-bound, 0.0]); + Tensor::from_vec(src, (size,), device) +} + +/// Round a vector +fn round_vector(values: &[f32]) -> Vec { + values + .iter() + .map(|x| (1000. * x).round() / 1000.) + .collect::>() +} + +fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) { + for (i, (value, expected_value)) in values.iter().zip(expected.iter()).enumerate() { + let difference = (value - expected_value).abs(); + + assert!( + difference < tolerance, + "Error at index {}: value = {}, expected = {}. 
Difference = {} exceeds tolerance = {}.", + i, + value, + expected_value, + difference, + tolerance + ); + } +} + +/// Creates a vector similar to the ones used in GGML unit tests: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 +fn create_ggml_like_vector(offset: f32) -> Vec { + (0..GGML_TEST_SIZE) + .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) + .collect() +} + +/// Calculates the root mean square error between two vectors +fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 { + assert_eq!(a.len(), b.len()); + let sum = a + .iter() + .zip(b) + .map(|(a, b)| (a - b).powi(2)) + .sum::() + .sqrt(); + sum / a.len() as f32 +} + +/// Similar to the GGML quantization unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 +fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> { + let src = create_ggml_like_vector(0.0); + let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + let error = calculate_rmse(&src.to_vec1::()?, &dst.to_vec1::()?); + if error > max_error { + bail!( + "Quantization error {} exceeds max error {}", + error, + max_error + ); + } + Ok(()) +} + +fn quantize_q2k(device: &Device) -> Result<()> { + let dtype = GgmlDType::Q2K; + + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.1); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?; + Ok(()) +} + +fn quantize_q3k(device: &Device) -> Result<()> { + let dtype = GgmlDType::Q3K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.03); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?; + Ok(()) +} + +fn quantize_q4k(device: &Device) -> Result<()> { + let dtype = GgmlDType::Q4K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.017); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn quantize_q5k(device: &Device) -> Result<()> { + let dtype = GgmlDType::Q5K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.009); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn quantize_q6k(device: &Device) -> Result<()> { + let dtype = GgmlDType::Q6K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.008); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +fn quantize_q8k(device: &Device) -> Result<()> { + let dtype = GgmlDType::Q8K; + let src = get_test_vector2(0.5, 1024, device)?; + let quant = quantized::QTensor::quantize(&src, dtype)?; + let dst = quant.dequantize(device)?; + let dst_f16 = quant.dequantize_f16(device)?; + let diff = (dst.to_dtype(DType::F16)? - dst_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? 
+ .to_vec0::()?; + assert_eq!(diff, 0.); + + let src = src.to_vec1::()?; + let dst = dst.to_vec1::()?; + compare_with_error(dst.as_slice(), src.as_slice(), 0.008); + + // Test some specific values + assert_eq!( + [src[0], src[128], src[256], src[512], src[800], src[1023]], + [-0.5, -0.375, -0.25, 0.0, 0.28125, 0.49902344] + ); + let dst = round_vector(&dst); + assert_eq!( + [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], + [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499] + ); + + let src_big = get_test_vector2(128.0, 1024, device)?; + let quant_big = quantized::QTensor::quantize(&src_big, dtype)?; + let dst_big = quant_big.dequantize(device)?; + let dst_big_f16 = quant_big.dequantize_f16(device)?; + let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)? + .to_dtype(DType::F32)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + + let src_big = src_big.to_vec1::()?; + let dst_big = dst_big.to_vec1::()?; + compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6); + + ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?; + Ok(()) +} + +test_device!( + quantize_q4_0, + quantize_q4_0_cpu, + quantize_q4_0_cuda, + quantize_q4_0_metal +); +test_device!( + quantize_q4_1, + quantize_q4_1_cpu, + quantize_q4_1_cuda, + quantize_q4_1_metal +); +test_device!( + quantize_q5_0, + quantize_q5_0_cpu, + quantize_q5_0_cuda, + quantize_q5_0_metal +); +test_device!( + quantize_q5_1, + quantize_q5_1_cpu, + quantize_q5_1_cuda, + quantize_q5_1_metal +); +test_device!( + quantize_q2k, + quantize_q2k_cpu, + quantize_q2k_cuda, + quantize_q2k_metal +); +test_device!( + quantize_q3k, + quantize_q3k_cpu, + quantize_q3k_cuda, + quantize_q3k_metal +); +test_device!( + quantize_q4k, + quantize_q4k_cpu, + quantize_q4k_cuda, + quantize_q4k_metal +); +test_device!( + quantize_q5k, + quantize_q5k_cpu, + quantize_q5k_cuda, + quantize_q5k_metal +); +test_device!( + quantize_q6k, + quantize_q6k_cpu, + quantize_q6k_cuda, + quantize_q6k_metal +); +test_device!( + quantize_q8k, + quantize_q8k_cpu, + quantize_q8k_cuda, + quantize_q8k_metal +); + +/// Very simple dot product implementation +fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| a * b).sum() +} + +/// Returns the error achieved by the GGML matmul unit test. +fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { + let err = match dtype { + GgmlDType::F16 => 0.000010, + GgmlDType::Q2K => 0.004086, + GgmlDType::Q3K => 0.016148, + GgmlDType::Q4K => 0.002425, + GgmlDType::Q5K => 0.000740, + GgmlDType::Q6K => 0.000952, + GgmlDType::Q4_0 => 0.001143, + GgmlDType::Q4_1 => 0.008, + GgmlDType::Q5_0 => 0.001353, + GgmlDType::Q5_1 => 0.00149, + GgmlDType::Q8_0 => 0.000092, + + // Not from the ggml repo. 
+ GgmlDType::Q8K => 0.00065, + _ => bail!("No GGML results for quantization type {dtype:?}",), + }; + Ok(err) +} + +/// Similar to the GGML matmul unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 +fn ggml_matmul_error_test() -> Result<()> { + let a = create_ggml_like_vector(0.0); + let b = create_ggml_like_vector(1.0); + ggml_matmul_error_test_::(a.as_slice(), b.as_slice(), 1.0)?; + // Another example that is more likely to trigger the overflow reported in #1526 + let a = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::>(); + let b = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::>(); + ggml_matmul_error_test_::(a.as_slice(), b.as_slice(), 2.0)?; + Ok(()) +} + +fn ggml_matmul_error_test_(a: &[f32], b: &[f32], err_m: f32) -> Result<()> { + let length = a.len(); + + let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; + let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; + T::from_float(a, &mut a_quant)?; + T::VecDotType::from_float(b, &mut b_quant)?; + + let result = T::vec_dot(length, &a_quant, &b_quant)?; + let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?; + let reference_result = vec_dot_reference(a, b); + + if (result - result_unopt).abs() / length as f32 > 1e-6 { + bail!( + "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}" + ) + } + + let error = (result - reference_result).abs() / length as f32; + + let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; + + if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { + bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); + } + + // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML + // => we use a slightly higher error threshold + const ERROR_LENIENCY: f32 = 0.00001; + if error - ERROR_LENIENCY > ggml_error { + bail!( + "Dot product error {} exceeds ggml reference error {}", + error, + ggml_error + ); + } + Ok(()) +} + +#[test] +fn quantized_mm() -> Result<()> { + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + ggml_matmul_error_test::()?; + Ok(()) +} + +/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result. +fn get_random_tensors( + m: usize, + k: usize, + n: usize, + device: &Device, +) -> Result<(Tensor, Tensor, Tensor)> { + let mut rng = StdRng::seed_from_u64(314159265358979); + + let lhs = (0..m * k) + .map(|_| rng.gen::() - 0.5) + .collect::>(); + let rhs = (0..n * k) + .map(|_| rng.gen::() - 0.5) + .collect::>(); + + let lhs = Tensor::from_vec(lhs, (m, k), device)?; + let rhs = Tensor::from_vec(rhs, (n, k), device)?; + + let mm = lhs.matmul(&rhs.t()?)?; + Ok((lhs, rhs, mm)) +} + +#[macro_export] +macro_rules! quantized_matmul { + // TODO: Switch to generating the two last arguments automatically once concat_idents is + // stable. 
https://github.com/rust-lang/rust/issues/29599 + ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => { + fn $fn_name(device: &Device) -> Result<()> { + test_matmul(device, (1, 3, 4, 256), $dtype)?; + Ok(()) + } + + test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal); + }; +} + +quantized_matmul!( + quantized_matmul_q4_0_bis, + quantized_matmul_q4_0_cpu, + quantized_matmul_q4_0_cuda, + quantized_matmul_q4_0_metal, + GgmlDType::Q4_0 +); +quantized_matmul!( + quantized_matmul_q4_1_bis, + quantized_matmul_q4_1_cpu, + quantized_matmul_q4_1_cuda, + quantized_matmul_q4_1_metal, + GgmlDType::Q4_1 +); +quantized_matmul!( + quantized_matmul_q5_0_bis, + quantized_matmul_q5_0_cpu, + quantized_matmul_q5_0_cuda, + quantized_matmul_q5_0_metal, + GgmlDType::Q5_0 +); +quantized_matmul!( + quantized_matmul_q5_1_bis, + quantized_matmul_q5_1_cpu, + quantized_matmul_q5_1_cuda, + quantized_matmul_q5_1_metal, + GgmlDType::Q5_1 +); +quantized_matmul!( + quantized_matmul_q8_0_bis, + quantized_matmul_q8_0_cpu, + quantized_matmul_q8_0_cuda, + quantized_matmul_q8_0_metal, + GgmlDType::Q8_0 +); +// Not implemented in Ggml +// quantized_matmul!( +// quantized_matmul_q8_1_bis, +// quantized_matmul_q8_1_cpu, +// quantized_matmul_q8_1_cuda, +// quantized_matmul_q8_1_metal, +// GgmlDType::Q8_1 +// ); +// TODO This is bugged (also bugged in GGML +quantized_matmul!( + quantized_matmul_q2k_bis, + quantized_matmul_q2k_cpu, + quantized_matmul_q2k_cuda, + quantized_matmul_q2k_metal, + GgmlDType::Q2K +); +quantized_matmul!( + quantized_matmul_q3k_bis, + quantized_matmul_q3k_cpu, + quantized_matmul_q3k_cuda, + quantized_matmul_q3k_metal, + GgmlDType::Q3K +); +quantized_matmul!( + quantized_matmul_q4k_bis, + quantized_matmul_q4k_cpu, + quantized_matmul_q4k_cuda, + quantized_matmul_q4k_metal, + GgmlDType::Q4K +); +quantized_matmul!( + quantized_matmul_q5k_bis, + quantized_matmul_q5k_cpu, + quantized_matmul_q5k_cuda, + quantized_matmul_q5k_metal, + GgmlDType::Q5K +); +quantized_matmul!( + quantized_matmul_q6k_bis, + quantized_matmul_q6k_cpu, + quantized_matmul_q6k_cuda, + quantized_matmul_q6k_metal, + GgmlDType::Q6K +); +// Not implemented on metal +// quantized_matmul!( +// quantized_matmul_q8k_bis, +// quantized_matmul_q8k_cpu, +// quantized_matmul_q8k_cuda, +// quantized_matmul_q8k_metal, +// GgmlDType::Q8K +// ); + +#[test] +fn quantized_matmul_q2k() -> Result<()> { + use k_quants::BlockQ2K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [0.916, 0.422, 0.215, 1.668]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + +#[test] +fn quantized_matmul_q3k() -> Result<()> { + use k_quants::BlockQ3K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], 
dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + +#[test] +fn quantized_matmul_q4k() -> Result<()> { + use k_quants::BlockQ4K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]); + + ggml_matmul_error_test::()?; + + Ok(()) +} + +#[test] +fn quantized_matmul_q5k() -> Result<()> { + use k_quants::BlockQ5K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.192, 1.491, -0.18, 1.743]); + + //Expected: 0.000740408897 + ggml_matmul_error_test::()?; + + Ok(()) +} + +#[test] +fn quantized_matmul_q6k() -> Result<()> { + use k_quants::BlockQ6K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.324, 1.49, -0.164, 1.741]); + + ggml_matmul_error_test::()?; + Ok(()) +} + +#[test] +fn quantized_matmul_q8k() -> Result<()> { + use k_quants::BlockQ8K; + + let cpu = &Device::Cpu; + let (m, k, n) = (11, 512, 21); + let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); + + let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?; + let rhs = quantized::QMatMul::from_qtensor(rhs)?; + let mm = 
rhs.forward(&lhs)?; + + assert_eq!(mm.dims(), [m, n]); + let dst = mm.flatten_all()?.to_vec1::()?; + let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); + assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]); + + ggml_matmul_error_test::()?; + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/serialization_tests.rs b/mcandle/mcandle-core/tests/serialization_tests.rs new file mode 100644 index 000000000..e66479e29 --- /dev/null +++ b/mcandle/mcandle-core/tests/serialization_tests.rs @@ -0,0 +1,71 @@ +use mcandle_core::{DType, Result, Tensor}; + +struct TmpFile(std::path::PathBuf); + +impl TmpFile { + fn create(base: &str) -> TmpFile { + let filename = std::env::temp_dir().join(format!( + "mcandle-{}-{}-{:?}", + base, + std::process::id(), + std::thread::current().id(), + )); + TmpFile(filename) + } +} + +impl std::convert::AsRef for TmpFile { + fn as_ref(&self) -> &std::path::Path { + self.0.as_path() + } +} + +impl Drop for TmpFile { + fn drop(&mut self) { + std::fs::remove_file(&self.0).unwrap() + } +} + +#[test] +fn npy() -> Result<()> { + let npy = Tensor::read_npy("tests/test.npy")?; + assert_eq!( + npy.to_dtype(DType::U8)?.to_vec1::()?, + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ); + Ok(()) +} + +#[test] +fn npz() -> Result<()> { + let npz = Tensor::read_npz("tests/test.npz")?; + assert_eq!(npz.len(), 2); + assert_eq!(npz[0].0, "x"); + assert_eq!(npz[1].0, "x_plus_one"); + assert_eq!( + npz[1].1.to_dtype(DType::U8)?.to_vec1::()?, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + ); + Ok(()) +} + +#[test] +fn safetensors() -> Result<()> { + use mcandle_core::safetensors::Load; + + let tmp_file = TmpFile::create("st"); + let t = Tensor::arange(0f32, 24f32, &mcandle_core::Device::Cpu)?; + t.save_safetensors("t", &tmp_file)?; + // Load from file. + let st = mcandle_core::safetensors::load(&tmp_file, &mcandle_core::Device::Cpu)?; + let t2 = st.get("t").unwrap(); + let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0f32); + // Load from bytes. 
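+    // The raw file contents are parsed with SliceSafetensors, which should give
+    // back exactly the same tensor as the file-based loader above.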
+ let bytes = std::fs::read(tmp_file)?; + let st = mcandle_core::safetensors::SliceSafetensors::new(&bytes)?; + let t2 = st.get("t").unwrap().load(&mcandle_core::Device::Cpu); + let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0f32); + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/tensor_tests.rs b/mcandle/mcandle-core/tests/tensor_tests.rs new file mode 100644 index 000000000..8827b71c1 --- /dev/null +++ b/mcandle/mcandle-core/tests/tensor_tests.rs @@ -0,0 +1,1702 @@ +use mcandle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D}; + +fn zeros(device: &Device) -> Result<()> { + let tensor = Tensor::zeros((5, 2), DType::F32, device)?; + let (dim1, dim2) = tensor.dims2()?; + assert_eq!(dim1, 5); + assert_eq!(dim2, 2); + Ok(()) +} + +fn ones(device: &Device) -> Result<()> { + assert_eq!( + Tensor::ones((2, 3), DType::U8, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::U32, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::I16, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::I32, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::I64, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::F32, device)?.to_vec2::()?, + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::F64, device)?.to_vec2::()?, + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ); + assert_eq!( + Tensor::ones((2, 3), DType::F16, device)?.to_vec2::()?, + [ + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ], + [ + half::f16::from_f32(1.0), + half::f16::from_f32(1.0), + half::f16::from_f32(1.0) + ] + ], + ); + assert_eq!( + Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::()?, + [ + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ], + [ + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0), + half::bf16::from_f32(1.0) + ] + ], + ); + Ok(()) +} + +fn full(device: &Device) -> Result<()> { + assert_eq!( + Tensor::full(42u32, (2, 3), device)?.to_vec2::()?, + [[42, 42, 42], [42, 42, 42]], + ); + Ok(()) +} + +fn arange(device: &Device) -> Result<()> { + assert_eq!( + Tensor::arange(0u8, 5u8, device)?.to_vec1::()?, + [0, 1, 2, 3, 4], + ); + assert_eq!( + Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::()?, + [0, 2, 4], + ); + assert_eq!( + Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::()?, + [0, 3], + ); + assert_eq!( + Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::()?, + [5, 4, 3, 2, 1], + ); + Ok(()) +} + +fn add_mul(device: &Device) -> Result<()> { + let tensor = Tensor::new(&[3f32, 1., 4.], device)?; + let dim1 = tensor.dims1()?; + assert_eq!(dim1, 3); + let content: Vec = tensor.to_vec1()?; + assert_eq!(content, [3., 1., 4.]); + let tensor = Tensor::add(&tensor, &tensor)?; + let content: Vec = tensor.to_vec1()?; + assert_eq!(content, [6., 2., 8.]); + let tensor = Tensor::mul(&tensor, &tensor)?; + let content: Vec = tensor.to_vec1()?; + assert_eq!(content, [36., 4., 64.]); + Ok(()) +} + +fn tensor_2d(device: &Device) -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, device)?; + let dims = tensor.dims2()?; + assert_eq!(dims, (2, 5)); + let content: Vec> = tensor.to_vec2()?; + assert_eq!(content, data); + Ok(()) +} + +fn clamp(device: &Device) -> Result<()> { + let 
data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, device)?; + let tensor = tensor.clamp(1.5, 6.2)?; + assert_eq!( + tensor.to_vec2::()?, + [[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]], + ); + Ok(()) +} + +fn asort(device: &Device) -> Result<()> { + let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, device)?; + let indexes = tensor.arg_sort_last_dim(true)?; + assert_eq!( + indexes.to_vec2::()?, + [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]], + ); + let indexes = tensor.arg_sort_last_dim(false)?; + assert_eq!( + indexes.to_vec2::()?, + [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]], + ); + let (sorted, indexes) = tensor.sort_last_dim(true)?; + assert_eq!( + indexes.to_vec2::()?, + [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]], + ); + assert_eq!( + sorted.to_vec2::()?, + [[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]] + ); + let (sorted, indexes) = tensor.sort_last_dim(false)?; + assert_eq!( + indexes.to_vec2::()?, + [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]], + ); + assert_eq!( + sorted.to_vec2::()?, + [[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]] + ); + Ok(()) +} + +fn unary_op(device: &Device) -> Result<()> { + let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + test_utils::to_vec2_round(&tensor.gelu()?, 4)?, + [ + [-0.0036, 0.8412, 3.9999, -0.046, 0.3457], + [2.6911, -0.0647, -0.1091, 1.7353, 2.7933] + ] + ); + let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?; + let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?; + assert!(max_diff.to_vec0::()? < 5e-3); + assert_eq!( + test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?, + [ + [-0.004, 0.8413, 3.9999, -0.046, 0.3457], + [2.6906, -0.0647, -0.1091, 1.7353, 2.7928] + ] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.erf()?, 4)?, + [ + [-1.0, 0.8427, 1.0, -0.1125, 0.5205], + [0.9999, -0.9891, -0.3079, 0.9891, 0.9999] + ] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.silu()?, 4)?, + [ + [-0.1423, 0.7311, 3.9281, -0.0475, 0.3112], + [2.53, -0.2553, -0.1205, 1.5447, 2.6395] + ] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.ceil()?, 4)?, + [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.floor()?, 4)?, + [[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]] + ); + assert_eq!( + test_utils::to_vec2_round(&tensor.round()?, 4)?, + [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]] + ); + let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?; + assert_eq!( + test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?, + [2997.92, 314.16] + ); + assert_eq!( + test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?, + [3000.0, 300.] + ); + let tensor = Tensor::new( + &[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1], + device, + )?; + assert_eq!( + tensor.sign()?.to_vec1::()?, + [-1., -1., -1., 0., 0., 1., 1., 1., 1.] 
+ ); + let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?; + let y = tensor.elu(2.)?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, 0.0000, -1.7293, 3.0000] + ); + // This test failed on metal prior to the following PR: + // https://github.com/huggingface/mcandle/pull/2490 + let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, -1.7293, 0.0000, 3.0000] + ); + Ok(()) +} + +fn binary_op(device: &Device) -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor1 = Tensor::new(data, device)?; + let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]]; + let tensor2 = Tensor::new(data2, device)?; + let tensor = (&tensor1 + (&tensor1 * &tensor1)? / (&tensor1 + &tensor2))?; + let dims = tensor.dims2()?; + assert_eq!(dims, (2, 5)); + let content: Vec> = tensor.to_vec2()?; + assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]); + assert_eq!(content[1], [3.0, 1.5, 10.5, 12.0, 3.0]); + #[allow(clippy::eq_op)] + let tensor = (&tensor - &tensor)?; + let content: Vec> = tensor.to_vec2()?; + assert_eq!(content[0], [0., 0., 0., 0., 0.]); + + let min = tensor1.minimum(&(&tensor2 * 0.5)?)?; + let max = tensor1.maximum(&(&tensor2 * 0.5)?)?; + assert_eq!( + min.to_vec2::()?, + [[2.5, 1.0, 2.5, 1.0, 2.5], [1.0, 0.5, 3.5, 4.0, 1.0]], + ); + assert_eq!( + max.to_vec2::()?, + [[3.0, 2.5, 4.0, 2.5, 5.0], [2.0, 1.0, 7.0, 8.0, 2.0]] + ); + Ok(()) +} + +fn transpose(device: &Device) -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, device)?.t()?; + let dims = tensor.dims2()?; + assert_eq!(dims, (5, 2)); + assert_eq!( + tensor.to_vec2::()?, + &[[3f32, 2.], [1., 1.], [4., 7.], [1., 8.], [5., 2.]] + ); + assert_eq!(tensor.t()?.to_vec2::()?, data); + assert_eq!(tensor.contiguous()?.t()?.to_vec2::()?, data); + assert_eq!(((tensor + 1.)?.t()? 
- 1.)?.to_vec2::()?, data); + Ok(()) +} + +fn var(device: &Device) -> Result<()> { + // Values taken from https://pytorch.org/docs/stable/generated/torch.var.html + let data = &[ + [0.2035f32, 1.2959, 1.8101, -0.4644], + [1.5027, -0.3270, 0.5905, 0.6538], + [-1.5745, 1.3330, -0.5596, -0.6548], + [0.1264, -0.5080, 1.6420, 0.1992], + ]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + test_utils::to_vec2_round(&tensor.var_keepdim(1)?, 4)?, + &[[1.0631], [0.559], [1.4893], [0.8258]] + ); + Ok(()) +} + +fn sum(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.sum_keepdim(2)?.to_vec3::()?, + &[[[8], [15]], [[10], [18]]] + ); + assert_eq!( + tensor.sum_keepdim(0)?.to_vec3::()?, + &[[[5, 2, 11], [9, 7, 17]]], + ); + assert_eq!(tensor.sum_keepdim((0, 2, 1))?.to_vec3::()?, &[[[51]]],); + assert_eq!( + tensor.t()?.sum_keepdim(1)?.t()?.to_vec3::()?, + &[[[8], [15]], [[10], [18]]] + ); + assert_eq!( + tensor.sum_keepdim((2, 1))?.to_vec3::()?, + &[[[8 + 15]], [[10 + 18]]] + ); + let data: Vec = (0..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.sum_keepdim(0)?.to_vec1::()?, &[7998000]); + let tensor = tensor.reshape((2000, 2))?; + assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::()?, &[[7998000]]); + assert_eq!( + tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::()?, + &[[7998000]] + ); + assert_eq!( + tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::()?, + &[[7998000]] + ); + assert_eq!( + tensor.sum_keepdim(0)?.to_vec2::()?, + &[[3998000, 4000000]] + ); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::()?, &[[7998000]]); + assert_eq!( + tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::()?, + &[[7998000]] + ); + assert_eq!( + tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::()?, + &[[7998000]] + ); + assert_eq!( + tensor.sum_keepdim(0)?.to_vec2::()?, + &[[3998000, 4000000]] + ); + + let t1 = tensor.reshape((200, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor.sum_keepdim((0, 1, 2))?.to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor + .sum_keepdim(0)? + .sum_keepdim(2)? + .sum_keepdim(1)? + .to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor + .sum_keepdim(0)? + .sum_keepdim((1, 2))? + .to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor + .sum_keepdim(1)? + .sum_keepdim((0, 2))? 
+ .to_vec3::()?, + &[[[7998000]]] + ); + assert_eq!( + tensor.sum_keepdim(0)?.to_vec3::()?, + &[[ + [398000, 398200, 398400, 398600], + [398800, 399000, 399200, 399400], + [399600, 399800, 400000, 400200], + [400400, 400600, 400800, 401000], + [401200, 401400, 401600, 401800] + ]] + ); + } + Ok(()) +} + +fn min(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.min_keepdim(2)?.to_vec3::()?, + &[[[1], [1]], [[1], [2]]] + ); + assert_eq!( + tensor.min_keepdim(0)?.to_vec3::()?, + &[[[2, 1, 4], [1, 2, 8]]], + ); + let data: Vec = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.min_keepdim(0)?.to_vec1::()?, &[200]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::()?, + &[[200]] + ); + assert_eq!( + tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::()?, + &[[200]] + ); + assert_eq!(tensor.min_keepdim(0)?.to_vec2::()?, &[[200, 201]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::()?, + &[[200]] + ); + assert_eq!( + tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::()?, + &[[200]] + ); + assert_eq!(tensor.min_keepdim(0)?.to_vec2::()?, &[[200, 201]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .min_keepdim(0)? + .min_keepdim(2)? + .min_keepdim(1)? + .to_vec3::()?, + &[[[200]]] + ); + assert_eq!( + tensor.min_keepdim(0)?.to_vec3::()?, + &[[ + [200, 201, 202, 203], + [204, 205, 206, 207], + [208, 209, 210, 211], + [212, 213, 214, 215], + [216, 217, 218, 219] + ]] + ); + } + Ok(()) +} + +fn max(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.max_keepdim(2)?.to_vec3::()?, + &[[[4], [9]], [[7], [8]]] + ); + assert_eq!( + tensor.max_keepdim(0)?.to_vec3::()?, + &[[[3, 1, 7], [8, 5, 9]]], + ); + let data: Vec = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.max_keepdim(0)?.to_vec1::()?, &[3999]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::()?, + &[[3999]] + ); + assert_eq!( + tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::()?, + &[[3999]] + ); + assert_eq!(tensor.max_keepdim(0)?.to_vec2::()?, &[[3998, 3999]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::()?, + &[[3999]] + ); + assert_eq!( + tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::()?, + &[[3999]] + ); + assert_eq!(tensor.max_keepdim(0)?.to_vec2::()?, &[[3998, 3999]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .max_keepdim(0)? + .max_keepdim(2)? + .max_keepdim(1)? 
+ .to_vec3::()?, + &[[[3999]]] + ); + assert_eq!( + tensor.max_keepdim(0)?.to_vec3::()?, + &[[ + [3980, 3981, 3982, 3983], + [3984, 3985, 3986, 3987], + [3988, 3989, 3990, 3991], + [3992, 3993, 3994, 3995], + [3996, 3997, 3998, 3999] + ]] + ); + } + Ok(()) +} + +fn argmin(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.argmin_keepdim(2)?.to_vec3::()?, + &[[[1], [0]], [[1], [1]]] + ); + assert_eq!( + tensor.argmin_keepdim(0)?.to_vec3::()?, + &[[[1, 0, 0], [0, 1, 1]]], + ); + let data: Vec = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::()?, &[0]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor + .argmin_keepdim(0)? + .argmin_keepdim(1)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!( + tensor + .argmin_keepdim(1)? + .argmin_keepdim(0)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::()?, &[[0, 0]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor + .argmin_keepdim(0)? + .argmin_keepdim(1)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!( + tensor + .argmin_keepdim(1)? + .argmin_keepdim(0)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::()?, &[[0, 0]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .argmin_keepdim(0)? + .argmin_keepdim(2)? + .argmin_keepdim(1)? + .to_vec3::()?, + &[[[0]]] + ); + assert_eq!( + tensor.argmin_keepdim(0)?.to_vec3::()?, + &[[ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ]] + ); + } + Ok(()) +} + +fn argmax(device: &Device) -> Result<()> { + let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.argmax_keepdim(2)?.to_vec3::()?, + &[[[2], [2]], [[2], [0]]] + ); + assert_eq!( + tensor.argmax_keepdim(0)?.to_vec3::()?, + &[[[0, 0, 1], [1, 0, 0]]], + ); + let data: Vec = (200..4000u32).collect(); + let tensor = Tensor::new(data.as_slice(), device)?; + assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::()?, &[3799]); + let tensor = tensor.reshape((1900, 2))?; + assert_eq!( + tensor + .argmax_keepdim(0)? + .argmax_keepdim(1)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!( + tensor + .argmax_keepdim(1)? + .argmax_keepdim(0)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::()?, &[[1899, 1899]]); + + // Make the tensor non contiguous. + let tensor = tensor.t()?.contiguous()?.t()?; + assert_eq!( + tensor + .argmax_keepdim(0)? + .argmax_keepdim(1)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!( + tensor + .argmax_keepdim(1)? + .argmax_keepdim(0)? + .to_vec2::()?, + &[[0]] + ); + assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::()?, &[[1899, 1899]]); + + let t1 = tensor.reshape((190, 5, 4))?; + let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; + for tensor in [t1, t2] { + assert_eq!( + tensor + .argmax_keepdim(0)? + .argmax_keepdim(2)? + .argmax_keepdim(1)? 
+ .to_vec3::()?, + &[[[0]]] + ); + assert_eq!( + tensor.argmax_keepdim(0)?.to_vec3::()?, + &[[ + [189, 189, 189, 189], + [189, 189, 189, 189], + [189, 189, 189, 189], + [189, 189, 189, 189], + [189, 189, 189, 189], + ]] + ); + } + Ok(()) +} + +fn narrow(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.narrow(2, 1, 2)?.to_vec3::()?, + &[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]], + ); + assert_eq!( + tensor.narrow(1, 1, 1)?.to_vec3::()?, + &[[[1.0, 5.0, 9.0]], [[8.0, 2.0, 8.0]]], + ); + assert_eq!( + tensor.narrow(0, 0, 1)?.to_vec3::()?, + &[[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]], + ); + assert_eq!( + tensor.narrow(0, 1, 1)?.to_vec3::()?, + &[[[2.0, 1.0, 7.0], [8.0, 2.0, 8.0]]], + ); + // The following has been checked against PyTorch via: + // import torch + // t = torch.tensor([[[3., 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]) + // t.transpose(-1, -2).narrow(1, 1, 2) + assert_eq!( + tensor.t()?.narrow(1, 1, 2)?.to_vec3::()?, + &[[[1.0, 5.0], [4.0, 9.0]], [[1.0, 2.0], [7.0, 8.0]]], + ); + Ok(()) +} + +fn broadcast(device: &Device) -> Result<()> { + let data = &[3f32, 1., 4.]; + let tensor = Tensor::new(data, device)?; + assert_eq!( + tensor.broadcast_left((3, 1))?.to_vec3::()?, + &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]] + ); + Ok(()) +} + +fn slice_set(device: &Device) -> Result<()> { + let (b, h, max_t, d) = (2, 4, 7, 3); + let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?; + let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?; + cache.slice_set(&tensor, 2, 0)?; + let cache_t = cache.narrow(2, 0, 4)?; + let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0.); + cache.slice_set(&tensor, 2, 1)?; + let cache_t = cache.narrow(2, 1, 4)?; + let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0.); + let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?; + cache.slice_set(&ones, 2, 6)?; + let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0.); + let diff = (cache.narrow(2, 6, 1)? - 1.)? + .abs()? + .sum_all()? + .to_vec0::()?; + assert_eq!(diff, 0.); + Ok(()) +} + +fn cat(device: &Device) -> Result<()> { + // 1D + let t1 = Tensor::new(&[3f32, 1., 4.], device)?; + let t2 = Tensor::new(&[1f32, 5., 9., 2.], device)?; + let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], device)?; + assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::()?, [3f32, 1., 4.],); + assert_eq!( + Tensor::cat(&[&t1, &t2], 0)?.to_vec1::()?, + [3f32, 1., 4., 1., 5., 9., 2.], + ); + assert_eq!( + Tensor::cat(&[&t1, &t2, &t3], 0)?.to_vec1::()?, + [3f32, 1., 4., 1., 5., 9., 2., 6., 5., 3., 5., 8., 9.], + ); + + // 2D + let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]]; + let t1 = Tensor::new(data, device)?; + let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]]; + let t2 = Tensor::new(data2, device)?; + assert_eq!( + Tensor::cat(&[&t1, &t2], 0)?.to_vec2::()?, + [ + [3.0, 1.0, 4.0, 1.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0], + [5.0, 5.0, 5.0, 5.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0] + ] + ); + // PyTorch equivalent: + // import torch + // t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]]) + // t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]]) + // torch.cat([t1.t(), t2.t()], dim=1).t() + assert_eq!( + Tensor::cat(&[&t1.t()?, &t2.t()?], 1)? + .t()? 
+ .to_vec2::()?, + [ + [3.0, 1.0, 4.0, 1.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0], + [5.0, 5.0, 5.0, 5.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0] + ] + ); + assert_eq!( + Tensor::cat(&[&t1, &t2], 1)?.to_vec2::()?, + [ + [3.0, 1.0, 4.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0], + [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0] + ] + ); + + // 3D + let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?; + let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?; + let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?; + + let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?; + + let t1 = t1.t()?.contiguous()?.t()?; + let t2 = t2.t()?.contiguous()?.t()?; + let t3 = t3.t()?.contiguous()?.t()?; + let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?; + + let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?; + assert_eq!(diff.to_vec0::()?, 104.0); + assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::()?, 0); + assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::()?, 16); + assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::()?, 20); + assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::()?, 44); + assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::()?, 100); + assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::()?, 112); + assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::()?, 101); + assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::()?, 105); + assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::()?, 10013); + assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::()?, 10031); + Ok(()) +} + +fn embeddings(device: &Device) -> Result<()> { + let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?; + let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?; + let hs = t.embedding(&ids)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let hs = t.index_select(&ids, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + Ok(()) +} + +fn cmp(device: &Device) -> Result<()> { + let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?; + let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?; + assert_eq!(t1.eq(&t2)?.to_vec2::()?, &[[0, 0], [0, 1], [1, 0]]); + assert_eq!(t1.ne(&t2)?.to_vec2::()?, &[[1, 1], [1, 0], [0, 1]]); + assert_eq!(t1.le(&t2)?.to_vec2::()?, &[[1, 0], [1, 1], [1, 1]]); + assert_eq!(t1.lt(&t2)?.to_vec2::()?, &[[1, 0], [1, 0], [0, 1]]); + assert_eq!(t1.gt(&t2)?.to_vec2::()?, &[[0, 1], [0, 0], [0, 0]]); + assert_eq!(t1.ge(&t2)?.to_vec2::()?, &[[0, 1], [0, 1], [1, 0]]); + Ok(()) +} + +fn index_select(device: &Device) -> Result<()> { + let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?; + let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; + assert_eq!( + t.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + for dtype in [DType::U8, DType::U32, DType::I16, DType::I32, DType::I64] { + let ids = ids.to_dtype(dtype)?; + let hs = t.index_select(&ids, 1)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [0.0, 2.0, 1.0], + [3.0, 5.0, 4.0], + [6.0, 8.0, 7.0], + [9.0, 11.0, 10.0] + ] + ); + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] + ); + // Prior to https://github.com/huggingface/mcandle/pull/1022 + // There would be a bug where the last values in the result tensor would be set to 0. 
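+    // The ids below contain more entries than `t` has rows, so the tail of the
+    // result (the part affected by that bug) is checked as well.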
+ let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + ] + ); + + // Test when selecting dim > 0 with ids size different from elem count of + // target dim in source/input. + let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; + let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; + assert_eq!(t.to_vec2::()?, &[[1.0, 2.0], [3.0, 4.0]]); + let hs = t.index_select(&ids, 1)?; + assert_eq!(hs.to_vec2::()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + } + + Ok(()) +} + +fn index_add(device: &Device) -> Result<()> { + let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?; + let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; + assert_eq!( + t.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + let init = Tensor::ones((4, 2), DType::F32, device)?; + let hs = init.index_add(&ids, &t, 1)?; + assert_eq!( + hs.to_vec2::()?, + &[[1.0, 4.0], [4.0, 10.0], [7.0, 16.0], [10.0, 22.0]], + ); + let init = Tensor::zeros((4, 2), DType::F32, device)?; + let ids = Tensor::new(&[1u32, 0u32, 0u32], device)?; + let hs = init.index_add(&ids, &t, 1)?; + assert_eq!( + hs.to_vec2::()?, + &[[3.0, 0.0], [9.0, 3.0], [15.0, 6.0], [21.0, 9.0]], + ); + + let init = Tensor::zeros((6, 3), DType::F32, device)?; + let ids = Tensor::new(&[5u32, 0u32, 1u32, 0u32], device)?; + let hs = init.index_add(&ids, &t, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [12.0, 14.0, 16.0], + [6.0, 7.0, 8.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 1.0, 2.0] + ] + ); + Ok(()) +} + +fn slice_scatter(device: &Device) -> Result<()> { + let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; + assert_eq!( + t.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?; + assert_eq!( + t.slice_scatter0(&src, 0)?.to_vec2::()?, + &[ + [100.0, 101.0, 102.0], + [103.0, 104.0, 105.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + assert_eq!( + t.slice_scatter0(&src, 1)?.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [100.0, 101.0, 102.0], + [103.0, 104.0, 105.0], + [9.0, 10.0, 11.0] + ] + ); + assert_eq!( + t.slice_scatter0(&src, 2)?.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [100.0, 101.0, 102.0], + [103.0, 104.0, 105.0], + ] + ); + Ok(()) +} + +fn scatter_add(device: &Device) -> Result<()> { + let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; + assert_eq!( + t.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + let ids = Tensor::new(&[[0u32, 1, 2], [3, 4, 0], [3, 3, 1], [2, 0, 4]], device)?; + let init = Tensor::ones((4, 5), DType::F32, device)?; + let hs = init.scatter_add(&ids, &t, 1)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [1.0, 2.0, 3.0, 1.0, 1.0], + [6.0, 1.0, 1.0, 4.0, 5.0], + [1.0, 9.0, 1.0, 14.0, 1.0], + [11.0, 1.0, 10.0, 1.0, 12.0] + ] + ); + + let init = Tensor::ones((6, 3), DType::F32, device)?; + let hs = init.scatter_add(&ids, &t, 0)?; + assert_eq!( + hs.to_vec2::()?, + &[ + [1.0, 11.0, 6.0], + [1.0, 2.0, 9.0], + [10.0, 1.0, 3.0], + [10.0, 8.0, 1.0], + [1.0, 5.0, 12.0], + [1.0, 1.0, 1.0] + ] + ); + Ok(()) +} + +fn gather(device: &Device) -> Result<()> { + let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?; + let t = 
Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?; + assert_eq!( + t.to_vec2::()?, + &[ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 10.0, 11.0] + ] + ); + let hs = t.gather(&ids, 1)?; + assert_eq!(hs.to_vec2::()?, &[[0.0], [5.0], [7.0], [9.0]]); + let ids = Tensor::new( + &[[0u32, 0u32], [2u32, 0u32], [1u32, 1u32], [0u32, 2u32]], + device, + )?; + let hs = t.gather(&ids, 1)?; + assert_eq!( + hs.to_vec2::()?, + &[[0.0, 0.0], [5.0, 3.0], [7.0, 7.0], [9.0, 11.0]] + ); + let ids = Tensor::new(&[[0u32, 2u32, 0u32]], device)?; + let hs = t.gather(&ids, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 7.0, 2.0]]); + let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?; + let hs = t.gather(&ids, 0)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]); + + // Random data + + // Dim: 0 + let t = Tensor::new( + &[ + [ + [108_f32, -47., 16., -56., -83., -130., 210.], + [253., 95., 151., 228., -210., -123., -127.], + [-9., -217., 2., -78., 163., 245., -204.], + [-246., 79., -238., 88., -226., -184., 171.], + [8., -48., -153., 234., -34., 166., -153.], + [124., 0., -10., -61., -242., -15., -238.], + ], + [ + [12., -64., -199., 244., -240., 156., -128.], + [173., -57., 4., -198., 233., -110., 238.], + [95., 82., 0., 240., 53., -211., 209.], + [-122., 167., -212., 227., -144., 61., 118.], + [-63., -146., 200., 244., 168., -167., 116.], + [-125., -147., 110., -253., -178., -250., -18.], + ], + [ + [57., 86., -50., 56., 92., 205., -78.], + [-137., -156., -18., 248., -61., -239., 14.], + [-248., -30., -50., -70., -251., 250., -83.], + [-221., 67., 72., 59., -24., -154., 232.], + [-144., -23., -74., 5., 93., 171., 205.], + [46., -77., -38., -226., 246., 161., -17.], + ], + [ + [-153., -231., -236., 161., 126., 2., -22.], + [-229., -41., 209., 164., 234., 160., 57.], + [223., 254., -186., -162., -46., -160., -102.], + [65., 30., 213., -253., 59., 224., -154.], + [-82., -203., -177., 17., 31., -256., -246.], + [176., -135., -65., 54., -56., 210., 76.], + ], + [ + [-10., -245., 168., 124., -14., -33., -178.], + [25., -43., -39., 132., -89., 169., 179.], + [187., -215., 32., -133., 87., -7., -168.], + [-224., -215., -5., -230., -58., -162., 128.], + [158., -137., -122., -100., -202., -83., 136.], + [30., -185., -144., 250., 209., -40., 127.], + ], + [ + [-196., 108., -245., 122., 146., -228., 62.], + [-1., -66., 160., 137., 13., -172., -21.], + [244., 199., -164., 28., 119., -175., 198.], + [-62., 253., -162., 195., -95., -230., -211.], + [123., -72., -26., -107., -139., 64., 245.], + [11., -126., -182., 108., -12., 184., -127.], + ], + [ + [-159., 126., 176., 161., 73., -111., -138.], + [-187., 214., -217., -33., -223., -201., -212.], + [-61., -120., -166., -172., -95., 53., 196.], + [-33., 86., 134., -152., 154., -53., 74.], + [186., -28., -154., -174., 141., -109., 217.], + [82., 35., 252., 145., 181., 74., -87.], + ], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [ + [6_u32, 6, 4, 3, 4, 4, 6], + [3, 3, 2, 4, 4, 4, 6], + [3, 3, 0, 2, 4, 6, 4], + [2, 5, 1, 2, 6, 6, 1], + [2, 1, 6, 5, 3, 2, 3], + [6, 1, 0, 1, 0, 2, 6], + ], + [ + [4, 6, 4, 3, 3, 3, 2], + [4, 3, 2, 4, 4, 4, 6], + [2, 3, 0, 2, 4, 6, 4], + [6, 5, 1, 2, 6, 6, 1], + [4, 1, 6, 5, 3, 2, 3], + [1, 1, 0, 1, 0, 2, 6], + ], + [ + [3, 6, 4, 3, 3, 3, 2], + [2, 3, 2, 4, 4, 4, 6], + [4, 3, 0, 2, 4, 6, 4], + [0, 5, 1, 2, 6, 6, 1], + [6, 1, 6, 5, 3, 2, 3], + [4, 1, 0, 1, 0, 2, 6], + ], + [ + [0, 6, 4, 3, 3, 3, 2], + [5, 3, 2, 4, 4, 4, 6], + [0, 3, 0, 2, 4, 6, 4], + [3, 5, 1, 2, 6, 6, 1], + [0, 1, 
6, 5, 3, 2, 3], + [3, 1, 0, 1, 0, 2, 6], + ], + ], + device, + )?; + + let hs = t.gather(&ids, 0)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [ + [-159_f32, 126., 168., 161., -14., -33., -138.], + [-229., -41., -18., 132., -89., 169., -212.], + [223., 254., 2., -70., 87., 53., -168.], + [-221., 253., -212., 59., 154., -53., 118.], + [-144., -146., -154., -107., 31., 171., -246.], + [82., -147., -10., -253., -242., 161., -87.] + ], + [ + [-10., 126., 168., 161., 126., 2., -78.], + [25., -41., -18., 132., -89., 169., -212.], + [-248., 254., 2., -70., 87., 53., -168.], + [-33., 253., -212., 59., 154., -53., 118.], + [158., -146., -154., -107., 31., 171., -246.], + [-125., -147., -10., -253., -242., 161., -87.] + ], + [ + [-153., 126., 168., 161., 126., 2., -78.], + [-137., -41., -18., 132., -89., 169., -212.], + [187., 254., 2., -70., 87., 53., -168.], + [-246., 253., -212., 59., 154., -53., 118.], + [186., -146., -154., -107., 31., 171., -246.], + [30., -147., -10., -253., -242., 161., -87.] + ], + [ + [108., 126., 168., 161., 126., 2., -78.], + [-1., -41., -18., 132., -89., 169., -212.], + [-9., 254., 2., -70., 87., 53., -168.], + [65., 253., -212., 59., 154., -53., 118.], + [8., -146., -154., -107., 31., 171., -246.], + [176., -147., -10., -253., -242., 161., -87.] + ] + ] + ); + + // Dim: 1 + let t = Tensor::new( + &[ + [ + [-117_f32, -175., 69., -163.], + [200., 242., -21., -67.], + [179., 150., -126., -75.], + [-118., 38., -138., -13.], + [-221., 136., -185., 180.], + [58., 182., -204., -149.], + ], + [ + [3., -148., -58., -154.], + [-43., 45., -108., 4.], + [-69., -249., -71., -21.], + [80., 110., -152., -235.], + [-88., 7., 92., -250.], + [-186., 207., -242., 98.], + ], + [ + [238., 19., 64., -242.], + [-150., -97., 218., 58.], + [111., -233., 204., -212.], + [-242., -232., 83., 42.], + [153., 62., -251., 219.], + [-117., 36., -119., 10.], + ], + [ + [215., 159., -169., -27.], + [-83., 101., -88., 169.], + [-205., 93., 225., -64.], + [-162., 240., 214., 23.], + [-112., 6., 21., 245.], + [-38., 113., 93., 215.], + ], + [ + [91., -188., -148., 101.], + [74., 203., -35., 55.], + [-116., -130., -153., -96.], + [58., 22., -45., -194.], + [-221., -134., 73., 159.], + [-203., -254., 31., 235.], + ], + [ + [105., -53., 61., 186.], + [-195., 234., 75., -1.], + [51., 139., 160., -108.], + [-173., -167., 161., 19.], + [83., -246., 156., -222.], + [109., 39., -149., 137.], + ], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [[4_u32, 4, 4, 2]], + [[0, 4, 4, 3]], + [[1, 5, 3, 4]], + [[0, 3, 3, 2]], + [[1, 1, 5, 2]], + [[1, 4, 5, 4]], + ], + device, + )?; + + let hs = t.gather(&ids, 1)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[-221., 136., -185., -75.]], + [[3., 7., 92., -235.]], + [[-150., 36., 83., 219.]], + [[215., 240., 214., -64.]], + [[74., 203., 31., -96.]], + [[-195., -246., -149., -222.]] + ] + ); + + // Dim: 2 + let t = Tensor::new( + &[ + [[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]], + [[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]], + ], + device, + )?; + + let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?; + + let hs = t.gather(&ids, 2)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[202.], [-126.], [-65.], [80.]], + [[37.], [89.], [117.], [220.]] + ] + ); + + let t = Tensor::new( + &[ + [[-21_f32, -197.], [194., 122.]], + [[255., -106.], [-191., 250.]], + [[33., -117.], [43., 10.]], + [[-130., 238.], [-217., -92.]], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [[0_u32, 1], [1, 0]], + [[1, 0], [0, 1]], + [[0, 1], [0, 
1]], + [[1, 0], [1, 0]], + ], + device, + )?; + + let hs = t.gather(&ids, 2)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[-21., -197.], [122., 194.]], + [[-106., 255.], [-191., 250.]], + [[33., -117.], [43., 10.]], + [[238., -130.], [-92., -217.]] + ] + ); + + Ok(()) +} + +fn broadcasting(device: &Device) -> Result<()> { + let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?; + let t2 = Tensor::new(&[100f32, 200f32], device)?; + let s = t1.broadcast_add(&t2.reshape((2, 1))?)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[100.0, 101.0, 102.0], [203.0, 204.0, 205.0]], + [[106.0, 107.0, 108.0], [209.0, 210.0, 211.0]], + [[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]], + [[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]] + ] + ); + let s = t1.t()?.broadcast_add(&t2)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[100.0, 203.0], [101.0, 204.0], [102.0, 205.0]], + [[106.0, 209.0], [107.0, 210.0], [108.0, 211.0]], + [[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]], + [[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]] + ] + ); + let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[-100.0, -99.0, -98.0], [-197.0, -196.0, -195.0]], + [[-94.0, -93.0, -92.0], [-191.0, -190.0, -189.0]], + [[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]], + [[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]] + ] + ); + let s = t1.t()?.broadcast_sub(&t2)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[-100.0, -197.0], [-99.0, -196.0], [-98.0, -195.0]], + [[-94.0, -191.0], [-93.0, -190.0], [-92.0, -189.0]], + [[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]], + [[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]] + ] + ); + // Test a narrowed version as this uses a layout start_offset. + let t1 = t1.i(2..)?; + let s = t1.broadcast_add(&t2.reshape((2, 1))?)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[112.0, 113.0, 114.0], [215.0, 216.0, 217.0]], + [[118.0, 119.0, 120.0], [221.0, 222.0, 223.0]] + ] + ); + let s = t1.t()?.broadcast_add(&t2)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[112.0, 215.0], [113.0, 216.0], [114.0, 217.0]], + [[118.0, 221.0], [119.0, 222.0], [120.0, 223.0]] + ] + ); + let s = t1.broadcast_sub(&t2.reshape((2, 1))?)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[-88.0, -87.0, -86.0], [-185.0, -184.0, -183.0]], + [[-82.0, -81.0, -80.0], [-179.0, -178.0, -177.0]] + ] + ); + let s = t1.t()?.broadcast_sub(&t2)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[-88.0, -185.0], [-87.0, -184.0], [-86.0, -183.0]], + [[-82.0, -179.0], [-81.0, -178.0], [-80.0, -177.0]] + ] + ); + let t3 = Tensor::new(1f32, device)?.broadcast_div(&t2)?; + let s = t1.broadcast_mul(&t2.reshape((2, 1))?)?; + let s_div = t1.broadcast_div(&t3.reshape((2, 1))?)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[1200.0, 1300.0, 1400.0], [3000.0, 3200.0, 3400.0]], + [[1800.0, 1900.0, 2000.0], [4200.0, 4400.0, 4600.0]] + ] + ); + assert_eq!(s.to_vec3::()?, s_div.to_vec3::()?,); + let s = t1.t()?.broadcast_mul(&t2)?; + let s_div = t1.t()?.broadcast_div(&t3)?; + assert_eq!( + s.to_vec3::()?, + &[ + [[1200.0, 3000.0], [1300.0, 3200.0], [1400.0, 3400.0]], + [[1800.0, 4200.0], [1900.0, 4400.0], [2000.0, 4600.0]] + ] + ); + assert_eq!(s.to_vec3::()?, s_div.to_vec3::()?,); + Ok(()) +} + +fn randn(device: &Device) -> Result<()> { + let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + // Check that the seed gets updated by checking that + // a new series of numbers is generated each time + let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?; + assert_ne!(tensor.to_vec2::()?, 
tensor2.to_vec2::()?); + let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?; + assert_eq!(tensor.dims(), [5, 3]); + // Check that the seed gets updated by checking that + // a new series of numbers is generated each time + let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?; + assert_ne!(tensor.to_vec2::()?, tensor2.to_vec2::()?); + // We do not expect deterministic elements at any index. + // There once was a bug that had a deterministic zero element in evenly sized tensors. + const N: usize = 2; + let v = (0..100) + .map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::())) + .collect::>>()?; + assert!( + (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])), + "There are deterministic values in the randn tensors" + ); + let v = (0..100) + .map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::())) + .collect::>>()?; + assert!( + (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])), + "There are deterministic values in the rand tensors" + ); + Ok(()) +} + +fn zero_dim(device: &Device) -> Result<()> { + let t = Tensor::zeros((4, 0, 1), DType::F32, device)?; + assert_eq!(t.dims3()?, (4, 0, 1)); + let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?; + let t_cat = Tensor::cat(&[&t, &t2], 1)?; + assert_eq!(t_cat.dims3()?, (4, 3, 1)); + let t_cat = Tensor::cat(&[&t, &t], 1)?; + assert_eq!(t_cat.dims3()?, (4, 0, 1)); + let t_unary = t.sqrt()?; + assert_eq!(t_unary.dims3()?, (4, 0, 1)); + let t_plus = (&t + 1.)?; + assert_eq!(t_plus.dims3()?, (4, 0, 1)); + let t_mm = t2.matmul(&t.t()?)?; + assert_eq!(t_mm.dims3()?, (4, 3, 0)); + let t_mm = t.matmul(&t2.t()?)?; + assert_eq!(t_mm.dims3()?, (4, 0, 3)); + let t_mm = t.t()?.matmul(&t)?; + assert_eq!(t_mm.dims3()?, (4, 1, 1)); + Ok(()) +} + +test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); +test_device!(ones, ones_cpu, ones_gpu, ones_metal); +test_device!(full, full_cpu, full_gpu, full_metal); +test_device!(arange, arange_cpu, arange_gpu, arange_metal); +test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); +test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal); +test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal); +test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal); +test_device!(slice_set, ss_cpu, ss_gpu, ss_metal); +test_device!(cat, cat_cpu, cat_gpu, cat_metal); +test_device!(sum, sum_cpu, sum_gpu, sum_metal); +test_device!(min, min_cpu, min_gpu, min_metal); +test_device!(max, max_cpu, max_gpu, max_metal); +test_device!(argmax, argmax_cpu, argmax_gpu, argmax_metal); +test_device!(argmin, argmin_cpu, argmin_gpu, argmin_metal); +test_device!(transpose, transpose_cpu, transpose_gpu, transpose_metal); +test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal); +test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal); +test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal); +test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal); +test_device!( + broadcasting, + broadcasting_cpu, + broadcasting_gpu, + broadcasting_metal +); +test_device!( + index_select, + index_select_cpu, + index_select_gpu, + index_select_metal +); +test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal); +test_device!(gather, gather_cpu, gather_gpu, gather_metal); +test_device!( + scatter_add, + scatter_add_cpu, + scatter_add_gpu, + scatter_add_metal +); +test_device!( + slice_scatter, + slice_scatter_cpu, + slice_scatter_gpu, + slice_scatter_metal +); +test_device!(randn, randn_cpu, randn_gpu, 
randn_metal); +test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal); +test_device!(asort, asort_cpu, asort_gpu, asort_metal); +test_device!(var, var_cpu, var_gpu, var_metal); +test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal); + +// There was originally a bug on the CPU implementation for randn +// https://github.com/huggingface/mcandle/issues/381 +#[test] +fn randn_hasneg() -> Result<()> { + let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::()?; + if t.iter().all(|&v| v >= 0.) { + mcandle_core::bail!("all values in tensors are non-negative") + } + Ok(()) +} + +#[test] +fn pad_with_same() -> Result<()> { + let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?; + let t0 = t.pad_with_same(0, 1, 2)?; + assert_eq!( + t0.to_vec2::()?, + [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]] + ); + let t1 = t.pad_with_same(1, 1, 2)?; + assert_eq!( + t1.to_vec2::()?, + [[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]] + ); + Ok(()) +} + +#[test] +fn i64_abs() -> Result<()> { + let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?; + let t = t.abs()?; + assert_eq!(t.to_vec1::()?, [42, 1337]); + Ok(()) +} + +#[test] +fn tril_triu_eye() -> Result<()> { + let t = Tensor::tril2(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::()?, + [ + [1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 1.0, 1.0] + ], + ); + let t = Tensor::triu2(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::()?, + [ + [1.0, 1.0, 1.0, 1.0], + [0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 1.0] + ] + ); + let t = Tensor::eye(4, DType::F32, &Device::Cpu)?; + assert_eq!( + t.to_vec2::()?, + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ] + ); + Ok(()) +} + +#[test] +fn cumsum() -> Result<()> { + let t = &[3f32, 1., 4., 1., 5.]; + let t = Tensor::new(t, &Device::Cpu)?; + assert_eq!(t.cumsum(0)?.to_vec1::()?, [3., 4., 8., 9., 14.]); + let t = t.unsqueeze(1)?; + assert_eq!( + t.cumsum(0)?.to_vec2::()?, + [[3.0], [4.0], [8.0], [9.0], [14.0]] + ); + assert_eq!( + t.cumsum(1)?.to_vec2::()?, + [[3.0], [1.0], [4.0], [1.0], [5.0]] + ); + let t = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let t = Tensor::new(t, &Device::Cpu)?; + assert_eq!( + t.cumsum(1)?.to_vec2::()?, + [[3.0, 4.0, 8.0, 9.0, 14.0], [2.0, 3.0, 10.0, 18.0, 20.0]], + ); + assert_eq!( + t.cumsum(0)?.to_vec2::()?, + [[3.0, 1.0, 4.0, 1.0, 5.0], [5.0, 2.0, 11.0, 9.0, 7.0]] + ); + Ok(()) +} + +/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data. +/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon. +fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { + let a_vec: Vec = a.to_vec1()?; + let b_vec: Vec = b.to_vec1()?; + + assert_eq!(a_vec.len(), b_vec.len()); + for (a, b) in a_vec.iter().zip(b_vec.iter()) { + assert!((a - b).abs() < epsilon); + } + Ok(()) +} + +#[test] +fn log_sum_exp() -> Result<()> { + let input = Tensor::new( + &[ + [[1f64, 2., 3.], [4., 5., 6.]], + [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]], + ], + &Device::Cpu, + )?; + + let output = input.log_sum_exp(D::Minus1)?; + // The expectations obtained from pytorch. 
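+    // Roughly the following snippet (double precision to match the f64 input):
+    // import torch
+    // t = torch.tensor([[[1., 2., 3.], [4., 5., 6.]],
+    //                   [[-1000., -999., -1001.], [1000., 999., 1001.]]],
+    //                  dtype=torch.float64)
+    // torch.logsumexp(t, dim=-1)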
+ let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?; + assert_eq!(output.dims(), expected.dims()); + assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?; + + assert_eq!( + input.log_sum_exp((0, 1))?.to_vec1::()?, + [1000.0, 999.0, 1001.0] + ); + assert_eq!( + input.log_sum_exp(())?.to_vec3::()?, + input.to_vec3::()? + ); + + Ok(()) +} + +#[test] +fn pow() -> Result<()> { + let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let rhs = (&lhs - 2.)?; + let res = lhs.pow(&rhs)?; + assert_eq!( + test_utils::to_vec2_round(&res, 3)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]] + ); + Ok(()) +} + +#[test] +fn unfold() -> Result<()> { + let x = Tensor::arange(0i64, 3 * 2, &Device::Cpu)?.reshape((3, 2))?; + let unfolded = x.unfold(0, 2, 1)?; + dbg!(&unfolded); + assert_eq!( + unfolded.to_vec3::()?, + vec![[[0i64, 2], [1, 3]], [[2, 4], [3, 5]]] + ); + Ok(()) +} diff --git a/mcandle/mcandle-core/tests/test.npy b/mcandle/mcandle-core/tests/test.npy new file mode 100644 index 000000000..a3ff5af95 Binary files /dev/null and b/mcandle/mcandle-core/tests/test.npy differ diff --git a/mcandle/mcandle-core/tests/test.npz b/mcandle/mcandle-core/tests/test.npz new file mode 100644 index 000000000..b6683caab Binary files /dev/null and b/mcandle/mcandle-core/tests/test.npz differ diff --git a/mcandle/mcandle-core/tests/test.pt b/mcandle/mcandle-core/tests/test.pt new file mode 100644 index 000000000..f2fa7da3e Binary files /dev/null and b/mcandle/mcandle-core/tests/test.pt differ diff --git a/mcandle/mcandle-core/tests/test_with_key.pt b/mcandle/mcandle-core/tests/test_with_key.pt new file mode 100644 index 000000000..a598e02c4 Binary files /dev/null and b/mcandle/mcandle-core/tests/test_with_key.pt differ diff --git a/mcandle/mcandle-flash-attn/Cargo.toml b/mcandle/mcandle-flash-attn/Cargo.toml new file mode 100644 index 000000000..2c1a997e2 --- /dev/null +++ b/mcandle/mcandle-flash-attn/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "mcandle-flash-attn" +version = "0.3.2" +edition = "2021" + +description = "Flash attention layer for the mcandle ML framework." +repository = "https://github.com/huggingface/mcandle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" +readme = "README.md" + +[dependencies] +mcandle = { path = "../mcandle-core", features = ["cuda"], package = "mcandle-core" } +half = { version = "2.3.1", features = ["num-traits"] } + +[build-dependencies] +bindgen_cuda = "0.1.1" +anyhow = { version = "1", features = ["backtrace"] } + + +[dev-dependencies] +anyhow = { version = "1", features = ["backtrace"] } +mcandle-nn = { path = "../mcandle-nn", features = ["cuda"] } diff --git a/mcandle/mcandle-flash-attn/README.md b/mcandle/mcandle-flash-attn/README.md new file mode 100644 index 000000000..dc9d82a89 --- /dev/null +++ b/mcandle/mcandle-flash-attn/README.md @@ -0,0 +1 @@ +# mcandle-flash-attn diff --git a/mcandle/mcandle-flash-attn/build.rs b/mcandle/mcandle-flash-attn/build.rs new file mode 100644 index 000000000..bd86a44ad --- /dev/null +++ b/mcandle/mcandle-flash-attn/build.rs @@ -0,0 +1,118 @@ +// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel. +// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment +// variable in order to cache the compiled artifacts and avoid recompiling too often. 
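+//
+// For example (path purely illustrative; the directory must already exist since
+// the build script canonicalizes it):
+//
+//   CANDLE_FLASH_ATTN_BUILD_DIR=$HOME/.cache/mcandle-flash-attn cargo build --release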
+use anyhow::{Context, Result}; +use std::path::PathBuf; + +const CUDA_NVCC_FLAGS: Option<&'static str> = option_env!("CUDA_NVCC_FLAGS"); + +const KERNEL_FILES: [&str; 33] = [ + "kernels/flash_api.cu", + "kernels/flash_fwd_hdim128_fp16_sm80.cu", + "kernels/flash_fwd_hdim160_fp16_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_sm80.cu", + "kernels/flash_fwd_hdim224_fp16_sm80.cu", + "kernels/flash_fwd_hdim256_fp16_sm80.cu", + "kernels/flash_fwd_hdim32_fp16_sm80.cu", + "kernels/flash_fwd_hdim64_fp16_sm80.cu", + "kernels/flash_fwd_hdim96_fp16_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_sm80.cu", + "kernels/flash_fwd_hdim160_bf16_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_sm80.cu", + "kernels/flash_fwd_hdim224_bf16_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_sm80.cu", + "kernels/flash_fwd_hdim32_bf16_sm80.cu", + "kernels/flash_fwd_hdim64_bf16_sm80.cu", + "kernels/flash_fwd_hdim96_bf16_sm80.cu", + "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim160_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim224_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim256_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim32_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim64_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim96_fp16_causal_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim160_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim224_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim32_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim64_bf16_causal_sm80.cu", + "kernels/flash_fwd_hdim96_bf16_causal_sm80.cu", +]; + +fn main() -> Result<()> { + println!("cargo:rerun-if-changed=build.rs"); + for kernel_file in KERNEL_FILES.iter() { + println!("cargo:rerun-if-changed={kernel_file}"); + } + println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); + println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); + println!("cargo:rerun-if-changed=kernels/flash.h"); + println!("cargo:rerun-if-changed=kernels/philox.cuh"); + println!("cargo:rerun-if-changed=kernels/softmax.h"); + println!("cargo:rerun-if-changed=kernels/utils.h"); + println!("cargo:rerun-if-changed=kernels/kernel_traits.h"); + println!("cargo:rerun-if-changed=kernels/block_info.h"); + println!("cargo:rerun-if-changed=kernels/static_switch.h"); + let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?); + let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { + Err(_) => + { + #[allow(clippy::redundant_clone)] + out_dir.clone() + } + Ok(build_dir) => { + let path = PathBuf::from(build_dir); + path.canonicalize().expect(&format!( + "Directory doesn't exists: {} (the current directory is {})", + &path.display(), + std::env::current_dir()?.display() + )) + } + }; + + let kernels = KERNEL_FILES.iter().collect(); + let mut builder = bindgen_cuda::Builder::default() + .kernel_paths(kernels) + .out_dir(build_dir.clone()) + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("-Icutlass/include") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose"); + + // https://github.com/EricLBuehler/mistral.rs/issues/286 + // https://github.com/huggingface/mcandle-flash-attn-v1/pull/2 + if let 
Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { + builder = builder.arg("--compiler-options"); + builder = builder.arg(cuda_nvcc_flags_env); + } + + let out_file = build_dir.join("libflashattention.a"); + builder.build_lib(out_file); + + println!("cargo:rustc-link-search={}", build_dir.display()); + println!("cargo:rustc-link-lib=flashattention"); + println!("cargo:rustc-link-lib=dylib=cudart"); + // https://github.com/denoland/rusty_v8/blob/20b2989186d1ecdf4c291d0706ff9eb1baaf2cfd/build.rs#L602 + let target = std::env::var("TARGET").unwrap(); + if target.contains("msvc") { + // nothing to link to + } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") { + println!("cargo:rustc-link-lib=dylib=c++"); + } else if target.contains("android") { + println!("cargo:rustc-link-lib=dylib=c++_shared"); + } else { + println!("cargo:rustc-link-lib=dylib=stdc++"); + } + + Ok(()) +} diff --git a/mcandle/mcandle-flash-attn/kernels/alibi.h b/mcandle/mcandle-flash-attn/kernels/alibi.h new file mode 100644 index 000000000..e714233e7 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/alibi.h @@ -0,0 +1,74 @@ +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Alibi { + + const float alibi_slope; + const int max_seqlen_k, max_seqlen_q; + + __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) + : alibi_slope(alibi_slope) + , max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) { + }; + + + template + __forceinline__ __device__ void apply_alibi(Tensor &tensor, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } + } + +}; + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/block_info.h b/mcandle/mcandle-flash-attn/kernels/block_info.h new file mode 100644 index 000000000..3a23a1e1f --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/block_info.h @@ -0,0 +1,46 @@ +/****************************************************************************** + * Copyright (c) 2023, 
Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/dropout.h b/mcandle/mcandle-flash-attn/kernels/dropout.h new file mode 100644 index 000000000..4882f97d9 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/dropout.h @@ -0,0 +1,94 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +struct Dropout { + + const unsigned long long seed, offset; + const uint8_t p_dropout_in_uint8_t; + + __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, + const uint8_t p_dropout_in_uint8_t, + const int bid, const int hid, const int tid, const int nheads) + : seed(seed) + , offset(offset + (bid * nheads + hid) * 32 + tid % 32) + , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { + } + + template + __forceinline__ __device__ void apply_dropout(Tensor &tensor_, + int block_row_start, int block_col_start, int block_row_stride) { + // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout())); + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? 
-val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. + if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } + } + +}; + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/error.h b/mcandle/mcandle-flash-attn/kernels/error.h new file mode 100644 index 000000000..034169240 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/error.h @@ -0,0 
+1,8 @@ +#pragma once + +#define C10_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + } while (0) + +#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) diff --git a/mcandle/mcandle-flash-attn/kernels/flash.h b/mcandle/mcandle-flash-attn/kernels/flash.h new file mode 100644 index 000000000..88c2f22a5 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash.h @@ -0,0 +1,194 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +// #ifdef OLD_GENERATOR_PATH +// #include +// #else +// #include +// #endif +// +// #include // For at::cuda::philox::unpack + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). 
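+    // (The C entry point in flash_api.cu fills this in as the keep probability and also quantizes
+    // it into p_dropout_in_uint8_t = floor(p_dropout * 255), the byte threshold that dropout.h
+    // compares the per-element random bytes against.)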
+ float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + float softcap; + + // Random state. + // at::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. 
+ void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/mcandle/mcandle-flash-attn/kernels/flash_api.cu b/mcandle/mcandle-flash-attn/kernels/flash_api.cu new file mode 100644 index 000000000..ca5f2b255 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_api.cu @@ -0,0 +1,133 @@ +#include "kernels.h" +#include "kernel_helpers.h" +#include "flash_fwd_launch_template.h" + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_mha_fwd_(params, stream); + }); + }); + }); +} + +extern "C" void run_mha( + void *q_ptr, + void *k_ptr, + void *v_ptr, + void *o_ptr, + void *softmax_lse_ptr, + void *alibi_slopes_ptr, + + int32_t *cu_seqlens_q_ptr, + int32_t *cu_seqlens_k_ptr, + + uint32_t q_batch_stride, + uint32_t k_batch_stride, + uint32_t v_batch_stride, + uint32_t o_batch_stride, + uint32_t alibi_slopes_batch_stride, + + uint32_t q_row_stride, + uint32_t k_row_stride, + uint32_t v_row_stride, + uint32_t o_row_stride, + + uint32_t q_head_stride, + uint32_t k_head_stride, + uint32_t v_head_stride, + uint32_t o_head_stride, + + uint32_t b, + uint32_t h, + uint32_t h_k, + uint32_t d, + uint32_t d_rounded, + float softmax_scale, + float softcap, + + uint32_t seqlen_q, + uint32_t seqlen_k, + uint32_t seqlen_q_rounded, + uint32_t seqlen_k_rounded, + + int is_bf16, + int is_causal, + + int window_size_left, + int window_size_right +) { + Flash_fwd_params params; + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = q_ptr; + params.k_ptr = k_ptr; + params.v_ptr = v_ptr; + params.o_ptr = o_ptr; + + params.softmax_lse_ptr = softmax_lse_ptr; + params.alibi_slopes_ptr = alibi_slopes_ptr; + + // All stride are in elements, not bytes. + params.q_batch_stride = q_batch_stride; + params.k_batch_stride = k_batch_stride; + params.v_batch_stride = v_batch_stride; + params.o_batch_stride = o_batch_stride; + params.alibi_slopes_batch_stride = alibi_slopes_batch_stride; + + params.q_row_stride = q_row_stride; + params.k_row_stride = k_row_stride; + params.v_row_stride = v_row_stride; + params.o_row_stride = o_row_stride; + params.q_head_stride = q_head_stride; + params.k_head_stride = k_head_stride; + params.v_head_stride = v_head_stride; + params.o_head_stride = o_head_stride; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. 
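+    // With soft-capping enabled the attention scores become softcap * tanh(q·k^T * softmax_scale / softcap):
+    // the factor applied before the tanh (softmax_scale / softcap) is stored in params.softcap and the
+    // factor applied after it (softcap) becomes the regular softmax scale below.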
+ if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + }else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } + + params.p_dropout = 1.; // probability to keep + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + params.is_bf16 = is_bf16; + params.cu_seqlens_q = cu_seqlens_q_ptr; + params.cu_seqlens_k = cu_seqlens_k_ptr; + params.p_ptr = nullptr; // used for `return_softmax`. + params.seqused_k = nullptr; + + params.is_causal = is_causal; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; + params.num_splits = 1; + + cudaStream_t stream = 0; // Use the default stream. + run_mha_fwd(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..f19049b49 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..cb1357419 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..dfb04b78b --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..6df16b2c3 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu new file mode 100644 index 000000000..230af9069 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..cf1ffad20 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu new file mode 100644 index 000000000..1fc5ac597 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..a9796aded --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu new file mode 100644 index 000000000..94792d4d3 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..76d5136b1 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu new file mode 100644 index 000000000..9e5b21e02 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..b4019a0be --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu new file mode 100644 index 000000000..a12a5f4ad --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..8690bdb1a --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu new file mode 100644 index 000000000..f01dad09c --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..7ec1e16b7 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu new file mode 100644 index 000000000..3d816ab60 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..c6c55229c --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu new file mode 100644 index 000000000..0149abacd --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..9c9a1715e --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu new file mode 100644 index 000000000..29097ac3a --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..cb52f34fa --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu new file mode 100644 index 000000000..7bdadefbe --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..44b388161 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu new file mode 100644 index 000000000..99cd728bc --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..c11096ac1 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu new file mode 100644 index 000000000..2fbcd44e6 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..7b65a9c9e --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu new file mode 100644 index 000000000..6fb3cf642 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..e696b2f2c --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu new file mode 100644 index 000000000..bb3b744d1 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..5f3accc30 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_kernel.h b/mcandle/mcandle-flash-attn/kernels/flash_fwd_kernel.h new file mode 100644 index 000000000..1bf77f81d --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_kernel.h @@ -0,0 +1,1298 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "mask.h" +#include "dropout.h" +#include "rotary.h" + +namespace flash { + +using namespace cute; + +template +__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bidb, const int bidh, const int m_block, const BlockInfo &binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. 
+ // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : ( + params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1) + ); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); +} + + +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + auto seed_offset = std::make_tuple(0ull, 0ull); + // auto seed_offset = at::cuda::philox::unpack(params.philox_args); + flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); + + // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might + // exit early and no one saves the rng states. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + params.rng_state[0] = std::get<0>(seed_offset); + params.rng_state[1] = std::get<1>(seed_offset); + } + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. 
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
+ + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K 
= smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
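The reshape described in the comment above is purely a relabeling of the same per-thread registers: two adjacent accumulator N-tiles of 4 values each are re-viewed as one 8-value A-operand fragment for the P*V GEMM, with no data movement. A shape-only host sketch of that bookkeeping follows (extents and index names are illustrative, not the actual CuTe call); the tOrP line below applies the corresponding CuTe relabeling to rP.

#include <cassert>
#include <cstdio>
#include <vector>

// Shape-only model of the (4, MMA_M, MMA_N) -> ((4, 2), MMA_M, MMA_N / 2)
// regrouping: the same flat registers are addressed through two index maps,
// where the "pair" bit of the new view keeps the old N stride.
int main() {
    const int kAcc = 4, kM = 2, kN = 6;          // illustrative fragment extents
    std::vector<float> regs(kAcc * kM * kN);     // stand-in for the per-thread registers

    // Flat index under the accumulator view (a fastest, then m, then n).
    auto acc_index = [&](int a, int m, int n) { return a + kAcc * m + kAcc * kM * n; };
    // Flat index under the re-grouped view ((a, pair), m, n2); original n == 2 * n2 + pair.
    auto aregs_index = [&](int a, int pair, int m, int n2) {
        return a + kAcc * kM * pair + kAcc * m + 2 * kAcc * kM * n2;
    };

    for (int n2 = 0; n2 < kN / 2; ++n2)
        for (int m = 0; m < kM; ++m)
            for (int pair = 0; pair < 2; ++pair)
                for (int a = 0; a < kAcc; ++a)
                    assert(aregs_index(a, pair, m, n2) == acc_index(a, m, 2 * n2 + pair));
    std::printf("both views address the same %zu registers\n", regs.size());
}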
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. 
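Stepping back from the tiling for a moment: softmax_rescale_o in the two loops above and normalize_softmax_lse just above implement a streaming softmax. Only a running row max and row exp-sum are kept, the partial output is rescaled whenever the max grows, and a single division plus the row's log-sum-exp happen once at the end. A single-row scalar sketch of that bookkeeping (softmax scaling, log2-space exponentials and dropout are omitted; all names and numbers are illustrative):

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar model of the streaming softmax: scores arrive in chunks (one per K/V
// block); old contributions are rescaled by exp(old_max - new_max) whenever the
// running max increases. Head dimension is 1, so the output is a single value.
int main() {
    const std::vector<float> scores = {0.3f, 2.1f, -1.0f, 0.7f, 1.5f, 0.2f};
    const std::vector<float> values = {1.0f, 2.0f,  3.0f, 4.0f, 5.0f, 6.0f};
    const int chunk = 2;  // plays the role of kBlockN

    float row_max = -INFINITY, row_sum = 0.f, acc_o = 0.f;
    for (size_t start = 0; start < scores.size(); start += chunk) {
        float chunk_max = -INFINITY;
        for (int i = 0; i < chunk; ++i) chunk_max = std::fmax(chunk_max, scores[start + i]);
        const float new_max = std::fmax(row_max, chunk_max);
        // Nothing to rescale on the first chunk; afterwards shrink the old sum and output.
        const float correction = (row_max == -INFINITY) ? 0.f : std::exp(row_max - new_max);
        row_sum *= correction;
        acc_o *= correction;
        row_max = new_max;
        for (int i = 0; i < chunk; ++i) {
            const float p = std::exp(scores[start + i] - row_max);  // unnormalized probability
            row_sum += p;
            acc_o += p * values[start + i];                         // P @ V with head dim 1
        }
    }
    // Epilogue: one division by the final sum, and the row's log-sum-exp is kept
    // for the backward pass / split combination.
    std::printf("out = %f, lse = %f\n", acc_o / row_sum, row_max + std::log(row_sum));
}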
+ if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
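compute_attn_1rowblock_splitkv, which begins above, assigns each split a contiguous range of K/V blocks; every split writes a partial output and LSE in float, and combine_attn_seqk_parallel merges them afterwards. A host-side sketch of the block-range arithmetic that appears a few lines below (causal and local-window clamping omitted; numbers invented):

#include <algorithm>
#include <cstdio>

// Model of the split-KV work partition: key/value blocks are divided into
// num_n_splits contiguous ranges, and a split with no blocks takes the
// kernel's early-exit path.
int main() {
    const int seqlen_k = 260, kBlockN = 64, num_n_splits = 4;
    const int n_blocks = (seqlen_k + kBlockN - 1) / kBlockN;                      // 5
    const int n_blocks_per_split = (n_blocks + num_n_splits - 1) / num_n_splits;  // 2
    for (int split = 0; split < num_n_splits; ++split) {
        const int n_block_min = split * n_blocks_per_split;
        const int n_block_max = std::min(n_blocks, (split + 1) * n_blocks_per_split);
        if (n_block_min >= n_block_max) {
            // Early exit: write 0 to Oaccum and -inf to LSEaccum so that the
            // combine kernel gives this split zero weight.
            std::printf("split %d: no blocks\n", split);
        } else {
            std::printf("split %d: K/V blocks [%d, %d)\n", split, n_block_min, n_block_max);
        }
    }
}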
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum + >; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK 
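The identity tensors constructed below give every element a thread copies its logical (row, column) coordinate, which is what makes the Is_even_K predicates possible: a vectorized load along the head dimension is kept only if its starting column lies inside the true head size d. A plain-C++ sketch of that predicate, assuming d is a multiple of the load width (all sizes invented):

#include <cstdio>

// Model of the tQpQ / tKVpKV predicates set up below: one bool per vectorized
// load along the head dimension, true if the load's first column is < d.
int main() {
    const int kHeadDim = 64;        // padded tile width
    const int d = 40;               // actual head dimension (params.d)
    const int kElemsPerLoad = 8;    // one 128-bit load of fp16 covers 8 columns
    for (int k = 0; k < kHeadDim / kElemsPerLoad; ++k) {
        const int col = k * kElemsPerLoad;   // what the identity tensor reports
        const bool in_bounds = col < d;      // the predicate value
        std::printf("load %d -> columns [%2d, %2d): %s\n",
                    k, col, col + kElemsPerLoad, in_bounds ? "copy" : "skip (or clear)");
    }
}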
+ Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. 
if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } + } + // Need this before we can read in K 
again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + apply_softcap(acc_s, params.softcap); + } + + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
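A note on the paged-KV branches in this loop (taken when block_table != nullptr): the K/V cache is stored in fixed-size pages that need not be contiguous in memory, so advancing gK/gV means re-deriving which page the new block starts in and the offset inside that page. A host-side sketch of that address arithmetic (table entries and strides are invented; the per-head term (bidh / h_h_k_ratio) * k_head_stride is left out):

#include <cstdint>
#include <cstdio>

// Model of paged-KV addressing: a logical token index maps to (page, offset in
// page), and the element offset mirrors
// block_table[page] * k_batch_stride + offset * k_row_stride.
int main() {
    const int page_block_size = 256;              // tokens per physical page
    const int block_table[] = {7, 2, 9, 4};       // pages backing this sequence, in order
    const int64_t k_batch_stride = 256 * 128;     // elements per page
    const int64_t k_row_stride = 128;             // elements per token row

    for (int token : {0, 255, 256, 700}) {
        const int page = token / page_block_size;
        const int in_page = token - page * page_block_size;
        const int64_t offset = int64_t(block_table[page]) * k_batch_stride
                             + int64_t(in_page) * k_row_stride;
        std::printf("token %3d -> table entry %d (physical page %d), row %3d, offset %lld\n",
                    token, page, block_table[page], in_page, (long long)offset);
    }
}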
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? + ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb) + ) + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
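The comment above relies on how the m16n8k16 MMA distributes its accumulator across a warp: each lane holds two values in each of two rows, so fixing the inner index of the (2, 2) split keeps one value per owned row, and the guard below then lets only the lanes whose first column is column 0 write the LSE. An illustrative per-lane enumeration, assuming the SM80 m16n8k16 fragment layout (this is a sketch, not the CuTe derivation):

#include <cstdio>

// Assumed SM80 m16n8k16 accumulator layout: lane l holds four values, at rows
// {l/4, l/4 + 8} and columns {(l%4)*2, (l%4)*2 + 1} of the 16x8 tile. Only
// lanes with l % 4 == 0 sit in column 0 and therefore write the per-row LSE.
int main() {
    for (int lane = 0; lane < 32; ++lane) {
        const int row0 = lane / 4, row1 = lane / 4 + 8;
        const int col0 = (lane % 4) * 2;
        if (col0 == 0) {
            std::printf("lane %2d holds column 0 of rows %2d and %2d -> writes their LSE\n",
                        lane, row0, row1);
        }
    }
}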
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. 
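To spell out the padding trick mentioned above: shared memory is served by 32 four-byte banks, and bank(index) = index % 32 for 4-byte elements. Walking a column of a [rows][kBlockM] float array (fixed column, varying row) steps the flat index by kBlockM, so a power-of-two kBlockM keeps hitting the same bank, while a row pitch of kBlockM + 1 staggers the accesses. A small sketch of the bank indices with and without padding (sizes illustrative):

#include <cstdio>

// Bank index of column 3, one access per row, for a [rows][kBlockM] float array
// versus a [rows][kBlockM + 1] array.
int main() {
    const int kBlockM = 8, rows = 8, col = 3;
    std::printf("row:          ");
    for (int r = 0; r < rows; ++r) std::printf("%3d", r);
    std::printf("\nbank, no pad: ");
    for (int r = 0; r < rows; ++r) std::printf("%3d", (r * kBlockM + col) % 32);      // repeats every 4 rows
    std::printf("\nbank, +1 pad: ");
    for (int r = 0; r < rows; ++r) std::printf("%3d", (r * (kBlockM + 1) + col) % 32); // all distinct here
    std::printf("\n");
}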
+ __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t lse_size = params.b * params.h * params.seqlen_q; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(lse_size, _1{})); + + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then transpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE[row][col] = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. 
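Concretely: each split produced lse_i = log(sum of exp(score)) over its own keys, plus an output already normalized by that partial sum. The global LSE is the max-shifted logsumexp of the per-split LSEs, and each split's output is weighted by exp(lse_i - lse) before being summed, which is what the code below and the accumulation loop further down do. A scalar sketch with two splits, checked against the unsplit result (head dimension 1, values invented):

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar model of the split combination: per-split (lse_i, o_i) are merged into
// the global (lse, o) and compared against a single pass over all keys.
int main() {
    const std::vector<float> scores = {0.3f, 2.1f, -1.0f, 0.7f};
    const std::vector<float> values = {1.0f, 2.0f,  3.0f, 4.0f};

    // Two splits of two keys each.
    float lse_split[2], o_split[2];
    for (int i = 0; i < 2; ++i) {
        float s = 0.f, num = 0.f;
        for (int k = 2 * i; k < 2 * i + 2; ++k) {
            s += std::exp(scores[k]);
            num += std::exp(scores[k]) * values[k];
        }
        lse_split[i] = std::log(s);
        o_split[i] = num / s;
    }

    // Combine: max-shifted logsumexp of the LSEs, then a weighted sum of the outputs.
    const float lse_max = std::fmax(lse_split[0], lse_split[1]);
    const float lse = lse_max + std::log(std::exp(lse_split[0] - lse_max) + std::exp(lse_split[1] - lse_max));
    const float o = std::exp(lse_split[0] - lse) * o_split[0] + std::exp(lse_split[1] - lse) * o_split[1];

    // Reference: one pass over all keys.
    float s = 0.f, num = 0.f;
    for (size_t k = 0; k < scores.size(); ++k) {
        s += std::exp(scores[k]);
        num += std::exp(scores[k]) * values[k];
    }
    std::printf("combined o = %f lse = %f | reference o = %f lse = %f\n", o, lse, num / s, std::log(s));
}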
+ ElementAccum lse_max = lse_accum(0); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } + } + // Store the scales exp(lse - lse_logsum) in shared memory. + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; + #pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { + #pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } + } + tOgOaccum.data() = 
tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print_tensor(tOrO); } + + Tensor rO = flash::convert_type(tOrO); + // Write to gO + #pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; + #pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/flash_fwd_launch_template.h b/mcandle/mcandle-flash-attn/kernels/flash_fwd_launch_template.h new file mode 100644 index 000000000..29918c87c --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/flash_fwd_launch_template.h @@ -0,0 +1,364 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +// #include + +#include "error.h" +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + flash::compute_attn(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_attn_splitkv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +} + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + // Also for headdim 160 with block size 64 x 128 after the rotary addition. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal>(params, stream); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); +} + +inline bool cuda_is_sm8x() { + // dprops = at::cuda::getCurrentDeviceProperties(); + // return dprops->major == 8 && dprops->minor > 0; + return false; +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; + bool is_sm8x = cuda_is_sm8x(); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static 
int Headdim = 128; + bool is_sm8x = cuda_is_sm8x(); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 160; + bool is_sm8x = cuda_is_sm8x(); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 192; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 224; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + // printf("max_smem_per_block = %d\n", max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, 
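// --- Illustrative aside (not part of the patch) -----------------------------
// The hdim-224 path above picks its tile only after querying how much opt-in
// shared memory a block may use. Sketch of that check: one Q tile plus K and V
// tiles in fp16 need 2 bytes * Headdim * (kBlockM + 2*kBlockN); for
// Headdim=224 with a 128x64 tile that is 114688 bytes (~112 KB).
// Error checking of the CUDA calls is elided here.
#include <cuda_runtime.h>

inline bool tile_fits_in_smem(int headdim, int kBlockM, int kBlockN) {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    const int required_bytes = 2 /* sizeof(fp16) */ * headdim * (kBlockM + 2 * kBlockN);
    return max_smem_per_block >= required_bytes;
}
// e.g. tile_fits_in_smem(224, 128, 64) asks for 114688 bytes.
// ----------------------------------------------------------------------------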
Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} diff --git a/mcandle/mcandle-flash-attn/kernels/kernel_helpers.h b/mcandle/mcandle-flash-attn/kernels/kernel_helpers.h new file mode 100644 index 000000000..22e40cc4a --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/kernel_helpers.h @@ -0,0 +1,50 @@ +// This header is not specific to our application and you'll probably want +// something like this for any extension you're building. This includes the +// infrastructure needed to serialize descriptors that are used with the +// "opaque" parameter of the GPU custom call. In our example we'll use this +// parameter to pass the size of our problem. + +#ifndef _GPU_OPS_KERNEL_HELPERS_H_ +#define _GPU_OPS_KERNEL_HELPERS_H_ + +#include +#include +#include +#include + +#define JAX_APEX_WARP_SIZE 32 + +namespace gpu_ops { + +// https://en.cppreference.com/w/cpp/numeric/bit_cast +template +typename std::enable_if::value && + std::is_trivially_copyable::value, + To>::type +bit_cast(const From &src) noexcept { + static_assert(std::is_trivially_constructible::value, + "This implementation additionally requires destination type to " + "be trivially constructible"); + + To dst; + memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template std::string PackDescriptorAsString(const T &descriptor) { + return std::string(bit_cast(&descriptor), sizeof(T)); +} + +template +const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) { + if (opaque_len != sizeof(T)) { + throw std::runtime_error("Invalid opaque object size"); + } + return bit_cast(opaque); +} + +} // namespace gpu_ops + +#endif + diff --git a/mcandle/mcandle-flash-attn/kernels/kernel_traits.h b/mcandle/mcandle-flash-attn/kernels/kernel_traits.h new file mode 100644 index 000000000..5a7b74911 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/kernel_traits.h @@ -0,0 +1,344 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
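// --- Illustrative aside (not part of the patch) -----------------------------
// kernel_helpers.h above serializes a trivially-copyable descriptor struct
// into the raw bytes of a std::string and recovers it later after a size
// check. A self-contained sketch of that round trip; ExampleDesc is a made-up
// struct, not one of the real descriptors, and memcpy stands in for bit_cast.
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

struct ExampleDesc { unsigned batch; unsigned seqlen; float scale; };

template <typename T>
std::string pack_descriptor(const T &d) {
    static_assert(std::is_trivially_copyable<T>::value, "descriptor must be POD-like");
    return std::string(reinterpret_cast<const char *>(&d), sizeof(T));
}

template <typename T>
T unpack_descriptor(const char *opaque, std::size_t opaque_len) {
    if (opaque_len != sizeof(T)) throw std::runtime_error("Invalid opaque object size");
    T out;
    std::memcpy(&out, opaque, sizeof(T));   // avoids alignment/aliasing issues
    return out;
}

// std::string blob = pack_descriptor(ExampleDesc{2, 4096, 0.125f});
// ExampleDesc d    = unpack_descriptor<ExampleDesc>(blob.data(), blob.size());
// ----------------------------------------------------------------------------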
+ ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
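// --- Illustrative aside (not part of the patch) -----------------------------
// The gmem-copy geometry above is pure arithmetic on the tile shape: each
// thread issues 16-byte (uint128) vectorized loads, so 8 fp16 elements per
// load, and a row of kBlockKSmem elements is covered by kBlockKSmem/8 threads.
// A constexpr sketch of those derivations, using a 2-byte stand-in for fp16.
#include <cstdint>

template <int kHeadDim, int kNWarps, typename Element = std::uint16_t /* fp16-sized */>
struct GmemGeometry {
    static constexpr int kNThreads         = kNWarps * 32;
    static constexpr int kBlockKSmem       = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kSwizzle          = kBlockKSmem == 32 ? 2 : 3;
    static constexpr int kGmemElemsPerLoad = 16 / int(sizeof(Element));   // 128-bit loads
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kHeadDim % kGmemElemsPerLoad == 0,
                  "kHeadDim must be a multiple of the vector width");
    static_assert(kNThreads % kGmemThreadsPerRow == 0,
                  "kNThreads must be a multiple of kGmemThreadsPerRow");
};

using G128 = GmemGeometry<128, 4>;   // d=128, 4 warps: 8 elems/load, 8 threads/row
static_assert(G128::kGmemThreadsPerRow == 8, "");
// ----------------------------------------------------------------------------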
2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype( + composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutQdOtransposed = decltype( + composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/mcandle/mcandle-flash-attn/kernels/kernel_traits_sm90.h b/mcandle/mcandle-flash-attn/kernels/kernel_traits_sm90.h new file mode 100644 index 000000000..e07f38390 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/kernel_traits_sm90.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
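// --- Illustrative aside (not part of the patch) -----------------------------
// Rough sketch of the backward-pass smem accounting above, in plain bytes.
// Tile shapes are hedged assumptions (kBlockM x kHeadDim for Q/dO and dQ,
// kBlockN x kHeadDim for K and V, kBlockM x kBlockN for P/dS), not copied
// verbatim from the kernel traits; elements are assumed to be 2-byte fp16.
constexpr int imax(int a, int b) { return a > b ? a : b; }

constexpr int bwd_smem_bytes(int kBlockM, int kBlockN, int kHeadDim,
                             bool No_double_buffer, bool Is_V_in_regs) {
    const int e   = 2;                                              // sizeof(fp16)
    const int qdo = kBlockM * kHeadDim * (No_double_buffer ? 2 : 3) * e;
    const int kv  = kBlockN * kHeadDim * 2 * e;                     // K and V
    const int ds  = kBlockM * kBlockN * e;
    const int p   = kBlockM * kBlockN * e;
    const int dq  = kBlockM * kHeadDim * e;
    return qdo + (!Is_V_in_regs ? kv + ds + imax(p, dq)
                                : imax(kv, kv / 2 + ds + imax(p, dq)));
}

static_assert(bwd_smem_bytes(64, 128, 64, false, false) > 0, "usable at compile time");
// ----------------------------------------------------------------------------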
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/mcandle/mcandle-flash-attn/kernels/kernels.h b/mcandle/mcandle-flash-attn/kernels/kernels.h new file mode 100644 index 000000000..20d6605f4 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/kernels.h @@ -0,0 +1,58 @@ +#ifndef _GPU_OPS_KERNELS_H_ +#define _GPU_OPS_KERNELS_H_ + +#include + +#include +#include + +#include +#include + +namespace gpu_ops { + +struct MHAParams { + uint32_t q_batch_stride; + uint32_t k_batch_stride; + uint32_t v_batch_stride; + uint32_t o_batch_stride; + + uint32_t q_row_stride; + uint32_t k_row_stride; + uint32_t v_row_stride; + uint32_t o_row_stride; + + uint32_t q_head_stride; + uint32_t k_head_stride; + uint32_t v_head_stride; + uint32_t o_head_stride; + + uint32_t b; + uint32_t h; + uint32_t h_k; + uint32_t d; + uint32_t d_rounded; + float softmax_scale; + float softcap; + + uint32_t seqlen_q; + uint32_t seqlen_k; + uint32_t seqlen_q_rounded; + uint32_t seqlen_k_rounded; + + int window_size_left; + int window_size_right; + + int is_causal; + 
int is_bf16; +}; + +void run_mha_fwd_j(cudaStream_t stream, void **buffers, + const char *opaque, + std::size_t opaque_len); +void run_mha_bwd_j(cudaStream_t stream, void **buffers, + const char *opaque, + std::size_t opaque_len); +} + +#endif diff --git a/mcandle/mcandle-flash-attn/kernels/mask.h b/mcandle/mcandle-flash-attn/kernels/mask.h new file mode 100644 index 000000000..7ba435a37 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/mask.h @@ -0,0 +1,213 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +namespace flash { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, 
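// --- Illustrative aside (not part of the patch) -----------------------------
// Host-side reference for the windowing arithmetic used by apply_mask_local /
// apply_mask_causal above: a score at (row_idx, col_idx) survives only inside
// [col_idx_limit_left, col_idx_limit_right). Causal masking is the special
// case of an infinite left window (negative value here) and window_size_right = 0.
#include <algorithm>

inline bool is_masked_local(int row_idx, int col_idx,
                            int max_seqlen_q, int max_seqlen_k,
                            int window_size_left, int window_size_right) {
    const int shift = max_seqlen_k - max_seqlen_q;     // right-aligns the two sequences
    const int limit_right = std::min(max_seqlen_k, row_idx + 1 + shift + window_size_right);
    if (col_idx >= limit_right) return true;
    if (window_size_left < 0) return false;            // no left window (causal case)
    const int limit_left = std::max(0, row_idx + shift - window_size_left);
    return col_idx < limit_left;
}

// is_masked_local(r, c, q, k, -1, 0) reproduces the causal rule c > r + (k - q).
// ----------------------------------------------------------------------------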
warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) +{ + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope=0.f) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? 
+ static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + +}; + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/philox.cuh b/mcandle/mcandle-flash-attn/kernels/philox.cuh new file mode 100644 index 000000000..cd7e4d2fa --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/philox.cuh @@ -0,0 +1,51 @@ +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +#pragma once +// Philox CUDA. 
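// --- Illustrative aside (not part of the patch) -----------------------------
// The two ALiBi forms above: in the causal / column-only path the bias is
// simply +slope * col (the per-row constant that completes the distance term
// cancels inside the row softmax), otherwise it is the usual
// -slope * |row + (seqlen_k - seqlen_q) - col| distance penalty.
// A scalar reference of both, handy for checking a fused kernel against.
#include <cmath>

inline float alibi_bias(int row_idx, int col_idx,
                        int max_seqlen_q, int max_seqlen_k,
                        float slope, bool causal_form) {
    if (causal_form) return slope * float(col_idx);    // shifted but softmax-equivalent
    const int shift = max_seqlen_k - max_seqlen_q;
    return -slope * std::fabs(float(row_idx + shift - col_idx));
}
// ----------------------------------------------------------------------------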
+ +namespace flash { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +__forceinline__ __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/rotary.h b/mcandle/mcandle-flash-attn/kernels/rotary.h new file mode 100644 index 000000000..7f1614ad2 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/rotary.h @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "utils.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { 
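// --- Illustrative aside (not part of the patch) -----------------------------
// Host-side reference of the Philox-4x32 round structure in philox.cuh: hi/lo
// products with the two 32-bit multipliers, XORed against counter and key,
// with the key bumped by the two Weyl constants each round (6 rounds in the
// loop plus one final round, matching the device code above).
#include <array>
#include <cstdint>

using u32 = std::uint32_t;
using u64 = std::uint64_t;
using Ctr = std::array<u32, 4>;
using Key = std::array<u32, 2>;

inline Ctr philox_round(Ctr c, Key k) {
    const u64 p0 = u64(0xD2511F53u) * c[0];            // kPhiloxSA * ctr.x
    const u64 p1 = u64(0xCD9E8D57u) * c[2];            // kPhiloxSB * ctr.z
    return { u32(p1 >> 32) ^ c[1] ^ k[0], u32(p1),
             u32(p0 >> 32) ^ c[3] ^ k[1], u32(p0) };
}

inline Ctr philox_ref(u64 seed, u64 subsequence, u64 offset) {
    Key k = { u32(seed), u32(seed >> 32) };
    Ctr c = { u32(offset), u32(offset >> 32), u32(subsequence), u32(subsequence >> 32) };
    for (int i = 0; i < 6; ++i) {
        c = philox_round(c, k);
        k[0] += 0x9E3779B9u;                           // kPhilox10A
        k[1] += 0xBB67AE85u;                           // kPhilox10B
    }
    return philox_round(c, k);                         // 7th round, no key bump after
}
// ----------------------------------------------------------------------------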
+ cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? 
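// --- Illustrative aside (not part of the patch) -----------------------------
// Scalar reference of the interleaved rotary update performed in fp32 above:
// each (even, odd) pair is treated as a complex number and rotated by the
// per-pair (cos, sin). Useful as ground truth when validating the kernel.
#include <vector>

inline void rope_interleaved_ref(std::vector<float> &x,
                                 const std::vector<float> &cos_v,
                                 const std::vector<float> &sin_v,
                                 int rotary_dim) {
    for (int i = 0; i < rotary_dim / 2; ++i) {
        const float re = x[2 * i] * cos_v[i] - x[2 * i + 1] * sin_v[i];
        const float im = x[2 * i] * sin_v[i] + x[2 * i + 1] * cos_v[i];
        x[2 * i]     = re;
        x[2 * i + 1] = im;
    }
    // dimensions past rotary_dim are copied through unchanged, as in the kernel
}
// ----------------------------------------------------------------------------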
-sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/softmax.h b/mcandle/mcandle-flash-attn/kernels/softmax.h new file mode 100644 index 000000000..ebf1b0979 --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/softmax.h @@ -0,0 +1,188 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? 
scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 
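// --- Illustrative aside (not part of the patch) -----------------------------
// Scalar reference of the online-softmax rescaling above for a single row:
// when a new block of scores arrives, the running row max may grow, so the
// previous partial sum and output accumulator are rescaled by
// exp2((max_prev - max_cur) * scale) before the new block's exp2 terms are
// folded in. Here `scale` plays the role of softmax_scale * log2(e), and the
// per-row output accumulator is collapsed to one float for brevity.
#include <algorithm>
#include <cmath>
#include <vector>

struct RowState { float row_max = -INFINITY; float row_sum = 0.f; float acc_o = 0.f; };

inline void online_softmax_step(RowState &st, const std::vector<float> &scores,
                                const std::vector<float> &values, float scale) {
    float block_max = st.row_max;
    for (float s : scores) block_max = std::max(block_max, s);
    const float max_cur = block_max == -INFINITY ? 0.f : block_max;  // avoid -inf - -inf
    const float rescale = st.row_max == -INFINITY
                        ? 1.f : std::exp2f((st.row_max - max_cur) * scale);
    st.row_sum *= rescale;
    st.acc_o   *= rescale;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        const float p = std::exp2f(scores[i] * scale - max_cur * scale);
        st.row_sum += p;
        st.acc_o   += p * values[i];
    }
    st.row_max = block_max;
}
// Final output is acc_o / row_sum; the LSE is row_max * softmax_scale + log(row_sum).
// ----------------------------------------------------------------------------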
0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/kernels/static_switch.h b/mcandle/mcandle-flash-attn/kernels/static_switch.h new file mode 100644 index 000000000..20c2afd6c --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/static_switch.h @@ -0,0 +1,117 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) 
\ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/mcandle/mcandle-flash-attn/kernels/utils.h b/mcandle/mcandle-flash-attn/kernels/utils.h new file mode 100644 index 000000000..708aeddfa --- /dev/null +++ b/mcandle/mcandle-flash-attn/kernels/utils.h @@ -0,0 +1,393 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template<> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template<> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = 
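// --- Illustrative aside (not part of the patch) -----------------------------
// Minimal sketch of the static_switch.h pattern above: a runtime flag is
// turned into a constexpr bool inside an immediately-invoked lambda, so the
// body can instantiate a template without repeating the if/else at every call
// site. run_templated and MY_BOOL_SWITCH are made-up example names.
#include <cstdio>

#define MY_BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                          \
        if (COND) {                                \
            constexpr bool CONST_NAME = true;      \
            return __VA_ARGS__();                  \
        } else {                                   \
            constexpr bool CONST_NAME = false;     \
            return __VA_ARGS__();                  \
        }                                          \
    }()

template <bool IsCausal>
void run_templated() { std::printf("causal=%d\n", int(IsCausal)); }

inline void dispatch(bool is_causal) {
    MY_BOOL_SWITCH(is_causal, IsCausalConst, [&] { run_templated<IsCausalConst>(); });
}
// ----------------------------------------------------------------------------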
reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + 
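// --- Illustrative aside (not part of the patch) -----------------------------
// The Allreduce helper above is a butterfly reduction over XOR shuffles:
// starting at THREADS/2 and halving the lane offset each step, every lane in
// the group ends up holding the combined value. A standalone CUDA sketch for a
// warp-wide sum; quad_allreduce_ in softmax.h is the same idea over 4 lanes.
__device__ __forceinline__ float warp_allreduce_sum(float x) {
    #pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;   // every lane now holds the 32-lane sum
}
// ----------------------------------------------------------------------------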
#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void relu_(Tensor &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +__forceinline__ __device__ auto 
convert_type_relu(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. 
+ // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, + const int max_MN=0, const int min_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/mcandle/mcandle-flash-attn/src/ffi.rs b/mcandle/mcandle-flash-attn/src/ffi.rs new file mode 100644 index 000000000..fe565beae --- /dev/null +++ b/mcandle/mcandle-flash-attn/src/ffi.rs @@ -0,0 +1,51 @@ +use core::ffi::{c_int, c_void}; + +extern "C" { + pub(crate) fn run_mha( + q_ptr: *const c_void, + k_ptr: *const c_void, + v_ptr: *const c_void, + o_ptr: *const c_void, + softmax_lse_ptr: *const c_void, + alibi_slopes_ptr: *const c_void, + + cu_seqlens_q_ptr: *const i32, + cu_seqlens_k_ptr: *const i32, + + q_batch_stride: u32, + k_batch_stride: u32, + v_batch_stride: u32, + o_batch_stride: u32, + alibi_slopes_batch_stride: u32, + + q_row_stride: u32, + k_row_stride: u32, + v_row_stride: u32, + o_row_stride: u32, + + q_head_stride: u32, + k_head_stride: u32, + v_head_stride: u32, + o_head_stride: u32, + + b: u32, + h: u32, + h_k: u32, + d: u32, + d_rounded: u32, + softmax_scale: f32, + softcap: f32, + + seqlen_q: u32, + seqlen_k: u32, + seqlen_q_rounded: u32, + seqlen_k_rounded: u32, + + is_bf16: c_int, + is_causal: c_int, + + window_size_left: c_int, + window_size_right: c_int, + ); + +} diff 
--git a/mcandle/mcandle-flash-attn/src/lib.rs b/mcandle/mcandle-flash-attn/src/lib.rs new file mode 100644 index 000000000..4483f2099 --- /dev/null +++ b/mcandle/mcandle-flash-attn/src/lib.rs @@ -0,0 +1,1036 @@ +mod ffi; + +use mcandle_core::backend::BackendStorage; +use mcandle_core::cuda_backend::cudarc::driver::DevicePtr; +use mcandle_core::cuda_backend::WrapErr; +use mcandle_core::{CpuStorage, DType, Layout, Result, Shape, Tensor}; +use half::{bf16, f16}; + +pub struct FlashAttn { + pub softmax_scale: f32, + pub softcap: Option, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, +} + +fn round_multiple(x: usize, m: usize) -> usize { + (x + m - 1) / m * m +} + +impl FlashAttn { + fn cuda_fwd_t< + T: mcandle_core::cuda_backend::CudaDType + mcandle_core::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + q: &mcandle_core::CudaStorage, + q_l: &Layout, + k: &mcandle_core::CudaStorage, + k_l: &Layout, + v: &mcandle_core::CudaStorage, + v_l: &Layout, + is_bf16: bool, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187 + let dev = q.device(); + let out_shape = q_l.shape().clone(); + let out_l = Layout::contiguous(&out_shape); + + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; + let q = q.slice(q_l.start_offset()..); + let k = k.slice(k_l.start_offset()..); + let v = v.slice(v_l.start_offset()..); + + let q_stride = q_l.stride(); + let k_stride = k_l.stride(); + let v_stride = v_l.stride(); + let o_stride = out_l.stride(); + + let q_rank = q_stride.len(); + let k_rank = k_stride.len(); + let v_rank = v_stride.len(); + let o_rank = o_stride.len(); + + if q_rank != 4 || k_rank != 4 || v_rank != 4 { + mcandle_core::bail!( + "flash-attn expects input tensors of rank 4 (q: {q_rank}, k: {k_rank}, v: {v_rank}" + ) + } + if q_stride[q_rank - 1] != 1 { + mcandle_core::bail!("the last dim of q must be contiguous {q_stride:?}") + } + if k_stride[k_rank - 1] != 1 { + mcandle_core::bail!("the last dim of k must be contiguous {k_stride:?}") + } + if v_stride[v_rank - 1] != 1 { + mcandle_core::bail!("the last dim of v must be contiguous {v_stride:?}") + } + + let (b_sz, seqlen_q, num_heads, head_size_og) = q_l.shape().dims4()?; + let (_b_sz, seqlen_k, num_heads_k, _head_size_og) = k_l.shape().dims4()?; + let expected_kv = (b_sz, seqlen_k, num_heads_k, head_size_og); + if expected_kv != k_l.shape().dims4()? { + mcandle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape()) + } + if expected_kv != v_l.shape().dims4()? { + mcandle_core::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape()) + } + if head_size_og > 256 { + mcandle_core::bail!("only supports head dimension at most 256 (got {head_size_og})") + } + if head_size_og % 8 != 0 { + // TODO: Handle head sizes that are not a multiple of 8 via some padding. 
+ mcandle_core::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})") + } + if num_heads % num_heads_k != 0 { + mcandle_core::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") + } + + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + mcandle_core::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? { + mcandle_core::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + mcandle_core::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => mcandle_core::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + let head_size = round_multiple(head_size_og, 8); + let head_size_rounded = round_multiple(head_size, 32); + let seqlen_q_rounded = round_multiple(seqlen_q, 128); + let seqlen_k_rounded = round_multiple(seqlen_k, 128); + + let elem_count = out_shape.elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let softmax_lse = dev + .alloc_zeros::(b_sz * 128 * num_heads * seqlen_q) + .w()?; + + let is_bf16 = if is_bf16 { 1 } else { 0 }; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = seqlen_k as i32; + } + + unsafe { + let q_ptr = *q.device_ptr() as *const core::ffi::c_void; + let k_ptr = *k.device_ptr() as *const core::ffi::c_void; + let v_ptr = *v.device_ptr() as *const core::ffi::c_void; + let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; + let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; + ffi::run_mha( + q_ptr, + k_ptr, + v_ptr, + dst_ptr, + softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, + /* cu_seqlens_q_ptr */ std::ptr::null(), + /* cu_seqlens_k_ptr */ std::ptr::null(), + /* q_batch_stride */ q_stride[0] as u32, + /* k_batch_stride */ k_stride[0] as u32, + /* v_batch_stride */ v_stride[0] as u32, + /* o_batch_stride */ o_stride[0] as u32, + /* alibi_slopes_batch_stride */ 0, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* b */ b_sz as u32, + /* h */ num_heads as u32, + /* h_k */ num_heads_k as u32, + /* d */ head_size as u32, + /* d_rounded */ head_size_rounded as u32, + /* softmax_scale*/ self.softmax_scale, + /* softcap */ self.softcap.unwrap_or(0.0), + /* seqlen_q */ seqlen_q as u32, + /* seqlen_k */ seqlen_k as u32, + /* seqlen_q_rounded */ seqlen_q_rounded as u32, + /* seqlen_k_rounded */ seqlen_k_rounded as u32, + /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* window_size_left */ window_size_left, + /* window_size_right */ window_size_right, + ) + } + + let dst = mcandle_core::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +} + +impl mcandle_core::CustomOp3 for FlashAttn { + fn name(&self) -> &'static str { + "flash-attn" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + mcandle_core::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &mcandle_core::CudaStorage, + q_l: &Layout, + k: &mcandle_core::CudaStorage, + k_l: &Layout, + v: &mcandle_core::CudaStorage, + v_l: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + match q.dtype() { + mcandle_core::DType::F16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, false), + mcandle_core::DType::BF16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, true), + dt => mcandle_core::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"), + } + } +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. 
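+/// * `softmax_scale` - Scaling factor applied to `Q @ K^T` before the softmax.
+/// * `causal` - Whether to apply a causal mask to `Q @ K^T` (equivalent to
+///   `window_size_right = Some(0)` in the windowed variants).
+///
+/// A minimal call sketch (illustrative only: `q`, `k` and `v` are assumed to be f16/bf16 CUDA
+/// tensors with the shapes listed above, and the usual `1 / sqrt(head_size)` scale is just an
+/// example):
+///
+/// ```ignore
+/// let softmax_scale = 1.0 / (head_size as f32).sqrt();
+/// let out = mcandle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal */ true)?;
+/// ```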
+/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + flash_attn_softcap(q, k, v, softmax_scale, None, causal) +} + +/// Equivalent to [`flash_attn`], but with softcap support +pub fn flash_attn_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + softcap, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + flash_attn_windowed_softcap( + q, + k, + v, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_windowed`], but with softcap support. +pub fn flash_attn_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + softcap, + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + flash_attn_alibi_softcap(q, k, v, alibi_slopes, softmax_scale, None, causal) +} + +/// Equivalent to [`flash_attn_alibi`], but with softcap support. 
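+///
+/// When `softcap` is `Some(t)`, the scaled attention scores are passed through
+/// `t * tanh(scores / t)` before the softmax, matching the reference implementation used in the
+/// crate's tests; `None` leaves the scores uncapped.
+///
+/// A minimal call sketch (illustrative; the cap value `30.0` is simply the one exercised in the
+/// tests):
+///
+/// ```ignore
+/// let out = mcandle_flash_attn::flash_attn_alibi_softcap(
+///     &q, &k, &v, &alibi_slopes, softmax_scale, Some(30.0), /* causal */ true,
+/// )?;
+/// ```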
+pub fn flash_attn_alibi_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + softcap: Option, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttn { + softmax_scale, + softcap, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +/// Flash-attention v2 layer. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +/// +/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`. +pub fn flash_attn_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttn { + softmax_scale, + softcap: None, + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +struct FlashAttnVarLen { + pub softmax_scale: f32, + pub softcap: Option, + pub max_seqlen_q: usize, + pub max_seqlen_k: usize, + pub seqlens_q: Tensor, + pub seqlens_k: Tensor, + pub alibi_slopes: Option, + pub window_size_left: Option, + pub window_size_right: Option, +} + +impl FlashAttnVarLen { + fn cuda_fwd_t< + T: mcandle_core::cuda_backend::CudaDType + mcandle_core::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + q: &mcandle_core::CudaStorage, + q_l: &Layout, + k: &mcandle_core::CudaStorage, + k_l: &Layout, + v: &mcandle_core::CudaStorage, + v_l: &Layout, + is_bf16: bool, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327 + let dev = q.device(); + let out_shape = q_l.shape().clone(); + let out_l = Layout::contiguous(&out_shape); + + let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout(); + let seqlens_q = match &*seqlens_q { + mcandle_core::Storage::Cuda(c) => c.as_cuda_slice::()?, // Should be i32! + _ => mcandle_core::bail!("seqlens_q must be a cuda tensor"), + }; + let seqlens_q = match seqlens_q_layout.contiguous_offsets() { + Some((o1, o2)) => seqlens_q.slice(o1..o2), + None => mcandle_core::bail!("seqlens_q has to be contiguous"), + }; + + let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout(); + let seqlens_k = match &*seqlens_k { + mcandle_core::Storage::Cuda(c) => c.as_cuda_slice::()?, // Should be i32! 
+ _ => mcandle_core::bail!("seqlens_k must be a cuda tensor"), + }; + let seqlens_k = match seqlens_k_layout.contiguous_offsets() { + Some((o1, o2)) => seqlens_k.slice(o1..o2), + None => mcandle_core::bail!("seqlens_k has to be contiguous"), + }; + + let q = q.as_cuda_slice::()?; + let k = k.as_cuda_slice::()?; + let v = v.as_cuda_slice::()?; + let q = q.slice(q_l.start_offset()..); + let k = k.slice(k_l.start_offset()..); + let v = v.slice(v_l.start_offset()..); + + let q_stride = q_l.stride(); + let k_stride = k_l.stride(); + let v_stride = v_l.stride(); + let o_stride = out_l.stride(); + + let q_rank = q_stride.len(); + let k_rank = k_stride.len(); + let v_rank = v_stride.len(); + let o_rank = o_stride.len(); + + if q_rank != 3 || k_rank != 3 || v_rank != 3 { + mcandle_core::bail!( + "flash-attn-varlen expects input tensors of rank 3 (q: {q_rank}, k: {k_rank}, v: {v_rank}" + ) + } + if q_stride[q_rank - 1] != 1 { + mcandle_core::bail!("the last dim of q must be contiguous {q_stride:?}") + } + if k_stride[k_rank - 1] != 1 { + mcandle_core::bail!("the last dim of k must be contiguous {k_stride:?}") + } + if v_stride[v_rank - 1] != 1 { + mcandle_core::bail!("the last dim of v must be contiguous {v_stride:?}") + } + + let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?; + let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?; + let expected_kv = (total_k, num_heads_k, head_size_og); + if expected_kv != k_l.shape().dims3()? { + mcandle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape()) + } + if expected_kv != v_l.shape().dims3()? { + mcandle_core::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape()) + } + if head_size_og > 256 { + mcandle_core::bail!("only supports head dimension at most 256 (got {head_size_og})") + } + if head_size_og % 8 != 0 { + // TODO: Handle head sizes that are not a multiple of 8 via some padding. + mcandle_core::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})") + } + if num_heads % num_heads_k != 0 { + mcandle_core::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}") + } + + let nseqlens_q = seqlens_q_layout.shape().dims1()?; + if nseqlens_q < 2 { + mcandle_core::bail!("seqlens_q should have a len >= 2 {nseqlens_q}") + } + let nseqlens_k = seqlens_k_layout.shape().dims1()?; + if nseqlens_k != nseqlens_q { + mcandle_core::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}") + } + + let batch_size = nseqlens_q - 1; + + let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes { + if alibi_slopes.dtype() != DType::F32 { + mcandle_core::bail!( + "DType mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes.dtype(), + DType::F32 + ); + } + + let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout(); + + if num_heads != alibi_slopes_layout.shape().dims1()? 
{ + mcandle_core::bail!( + "shape mismatch alibi_slopes {:?}, expected {:?}", + alibi_slopes_layout.shape(), + (num_heads) + ); + } + + let alibi_slopes = match &*alibi_slopes { + mcandle_core::Storage::Cuda(c) => c.as_cuda_slice::()?, + _ => mcandle_core::bail!("alibi_slopes must be a cuda tensor"), + }; + + let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..); + + *alibi_slopes.device_ptr() as *const core::ffi::c_void + } else { + std::ptr::null() + }; + + // if window_size_left > self.max_seqlen_k or None => -1 + let mut window_size_left = self + .window_size_left + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + // if window_size_right > self.max_seqlen_k or None => -1 + let mut window_size_right = self + .window_size_right + .filter(|v| v <= &self.max_seqlen_k) + .map(|v| v as i32) + .unwrap_or(-1); + + let head_size = round_multiple(head_size_og, 8); + let head_size_rounded = round_multiple(head_size, 32); + let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128); + let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128); + + let elem_count = out_shape.elem_count(); + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + let softmax_lse = dev + .alloc_zeros::(batch_size * num_heads * self.max_seqlen_q) + .w()?; + + let is_bf16 = if is_bf16 { 1 } else { 0 }; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + let is_causal = if window_size_left < 0 && window_size_right == 0 { + 1 + } else { + 0 + }; + if window_size_left < 0 && window_size_right >= 0 { + window_size_left = self.max_seqlen_k as i32; + } + if window_size_left >= 0 && window_size_right < 0 { + window_size_right = self.max_seqlen_k as i32; + } + + unsafe { + let q_ptr = *q.device_ptr() as *const core::ffi::c_void; + let k_ptr = *k.device_ptr() as *const core::ffi::c_void; + let v_ptr = *v.device_ptr() as *const core::ffi::c_void; + let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void; + let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void; + let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int; + let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int; + ffi::run_mha( + q_ptr, + k_ptr, + v_ptr, + dst_ptr, + softmax_lse_ptr, + /* alibi_slopes_ptr */ alibi_slopes_ptr, + /* cu_seqlens_q_ptr */ seqlens_q_ptr, + /* cu_seqlens_k_ptr */ seqlens_k_ptr, + /* q_batch_stride */ 0, + /* k_batch_stride */ 0, + /* v_batch_stride */ 0, + /* o_batch_stride */ 0, + /* alibi_slopes_batch_stride */ 0, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* b */ batch_size as u32, + /* h */ num_heads as u32, + /* h_k */ num_heads_k as u32, + /* d */ head_size as u32, + /* d_rounded */ head_size_rounded as u32, + /* softmax_scale*/ self.softmax_scale, + /* softcap */ self.softcap.unwrap_or(0.0), + /* seqlen_q */ self.max_seqlen_q as u32, + /* seqlen_k */ self.max_seqlen_k as u32, + /* seqlen_q_rounded */ seqlen_q_rounded as u32, + /* seqlen_k_rounded */ seqlen_k_rounded as u32, + /* is_bf16 */ is_bf16, + /* is_causal */ is_causal, + /* 
window_size_left */ window_size_left, + /* window_size_right */ window_size_right, + ) + } + + let dst = mcandle_core::CudaStorage::wrap_cuda_slice(dst, dev.clone()); + Ok((dst, out_shape)) + } +} + +impl mcandle_core::CustomOp3 for FlashAttnVarLen { + fn name(&self) -> &'static str { + "flash-attn-varlen" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + mcandle_core::bail!("no cpu support for flash-attn") + } + + fn cuda_fwd( + &self, + q: &mcandle_core::CudaStorage, + q_l: &Layout, + k: &mcandle_core::CudaStorage, + k_l: &Layout, + v: &mcandle_core::CudaStorage, + v_l: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + match q.dtype() { + mcandle_core::DType::F16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, false), + mcandle_core::DType::BF16 => self.cuda_fwd_t::(q, q_l, k, k_l, v, v_l, true), + dt => mcandle_core::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"), + } + } +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result { + flash_attn_varlen_softcap( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + causal, + ) +} + +/// Equivalent to [`flash_attn_varlen`], but with softcap support. +pub fn flash_attn_varlen_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + softcap, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. 
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + flash_attn_varlen_windowed_softcap( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_varlen_windowed`], but with softcap support. +pub fn flash_attn_varlen_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + softcap, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: None, + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. 
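+/// * `softmax_scale` - Scaling factor applied to `Q @ K^T` before the softmax.
+/// * `causal` - Whether to apply a causal mask to `Q @ K^T` (equivalent to
+///   `window_size_right = Some(0)` in the windowed variants).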
+/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +pub fn flash_attn_varlen_alibi( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + causal: bool, +) -> Result { + flash_attn_varlen_alibi_softcap( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + causal, + ) +} + +/// Equivalent to [`flash_attn_varlen_alibi`], but with softcap support +pub fn flash_attn_varlen_alibi_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + causal: bool, +) -> Result { + let window_size_left = None; + let window_size_right = if causal { Some(0) } else { None }; + + let op = FlashAttnVarLen { + softmax_scale, + softcap, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} + +#[allow(clippy::too_many_arguments)] +/// Flash-attention v2 layer with variable-length batching. +/// +/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`. +/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads +/// than q, the number of heads in k and v has to be divisible by the number of heads in q. +/// +/// # Arguments +/// +/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`. +/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`. +/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`. +/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q. +/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v. +/// * `max_seqlen_q` - The maximum query sequence length for q in the batch. +/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch. +/// * `window_size_left` - Limit left attention to value tokens. +/// * `window_size_right` - Limit right attention to value tokens. +/// +/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`, +/// `seqlen_1 + seqlen_2`, etc. +/// +/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`. +/// +/// # Causal mask +/// +/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result +/// of `Q @ K^T` +pub fn flash_attn_varlen_alibi_windowed( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + window_size_left: Option, + window_size_right: Option, +) -> Result { + flash_attn_varlen_alibi_windowed_softcap( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + None, + window_size_left, + window_size_right, + ) +} + +/// Equivalent to [`flash_attn_varlen_alibi_windowed`], but with softcap support. 
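+///
+/// A minimal call sketch (illustrative; `q`/`k`/`v` are packed `(total_tokens, num_heads, head_size)`
+/// CUDA tensors, `alibi_slopes`, `softmax_scale` and `device` are assumed to be in scope, and the
+/// window and cap values are arbitrary examples):
+///
+/// ```ignore
+/// // Cumulative sequence lengths for a batch of two sequences of lengths 3 and 5,
+/// // stored as u32 as in the crate's tests.
+/// let seqlens_q = Tensor::new(&[0u32, 3, 8], &device)?;
+/// let seqlens_k = Tensor::new(&[0u32, 3, 8], &device)?;
+/// let out = mcandle_flash_attn::flash_attn_varlen_alibi_windowed_softcap(
+///     &q, &k, &v, &alibi_slopes, &seqlens_q, &seqlens_k,
+///     /* max_seqlen_q */ 5, /* max_seqlen_k */ 5,
+///     softmax_scale, Some(30.0),
+///     /* window_size_left */ Some(64), /* window_size_right */ Some(0),
+/// )?;
+/// ```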
+pub fn flash_attn_varlen_alibi_windowed_softcap( + q: &Tensor, + k: &Tensor, + v: &Tensor, + alibi_slopes: &Tensor, + seqlens_q: &Tensor, + seqlens_k: &Tensor, + max_seqlen_q: usize, + max_seqlen_k: usize, + softmax_scale: f32, + softcap: Option, + window_size_left: Option, + window_size_right: Option, +) -> Result { + let op = FlashAttnVarLen { + softmax_scale, + softcap, + max_seqlen_q, + max_seqlen_k, + seqlens_q: seqlens_q.clone(), + seqlens_k: seqlens_k.clone(), + alibi_slopes: Some(alibi_slopes.clone()), + window_size_left, + window_size_right, + }; + q.apply_op3(k, v, op) +} diff --git a/mcandle/mcandle-flash-attn/tests/flash_attn_tests.rs b/mcandle/mcandle-flash-attn/tests/flash_attn_tests.rs new file mode 100644 index 000000000..168dfe55a --- /dev/null +++ b/mcandle/mcandle-flash-attn/tests/flash_attn_tests.rs @@ -0,0 +1,227 @@ +use anyhow::Result; +use mcandle_core::{DType, Device, IndexOp, Tensor, D}; + +fn to_vec3_round(t: Tensor, digits: i32) -> Result>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) +} + +fn fa_acausal( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + softcap: Option, +) -> Result { + let in_dtype = q.dtype(); + let q = q.to_dtype(DType::F32)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + let mut att = (q.matmul(&k.t()?)? * softmax_scale as f64)?; + if let Some(softcap) = softcap { + att = (att / softcap as f64)?; + att = att.tanh()?; + att = (att * softcap as f64)?; + } + let att = mcandle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?; + Ok(output) +} + +#[test] +fn flash_attn_acausal() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((1, 3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let ys1 = fa_acausal(&q, &k, &v, 0.5, None)?; + let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; + let ys2 = { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + mcandle_flash_attn::flash_attn(&q, &k, &v, 0.5, false)?.transpose(1, 2)? 
+ }; + let ys2 = ys2.i(0)?.to_dtype(DType::F32)?; + let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?; + + assert_eq!(ys1.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys1, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684] + ], + [ + [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955], + [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023] + ] + ] + ); + + assert_eq!(ys2.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys2, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684] + ], + [ + [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955], + [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023] + ] + ] + ); + assert!(diff.to_vec0::()?.abs() < 1e-5); + Ok(()) +} + +#[test] +fn flash_attn_varlen() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?; + let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + mcandle_flash_attn::flash_attn_varlen( + &q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false, + )? + .transpose(0, 1)? + }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684] + ], + [ + [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955], + [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023] + ] + ] + ); + Ok(()) +} + +#[test] +fn flash_attn_acausal_softcap() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? + .reshape((1, 3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let ys1 = fa_acausal(&q, &k, &v, 0.5, Some(30.))?; + let ys1 = ys1.i(0)?.to_dtype(DType::F32)?; + let ys2 = { + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + mcandle_flash_attn::flash_attn_softcap(&q, &k, &v, 0.5, Some(30.), false)?.transpose(1, 2)? + }; + let ys2 = ys2.i(0)?.to_dtype(DType::F32)?; + let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?; + + assert_eq!(ys1.dims(), &[3, 2, 8]); + assert_eq!(ys2.dims(), &[3, 2, 8]); + assert!(diff.to_vec0::()?.abs() < 1e-5); + Ok(()) +} + +#[test] +fn flash_attn_varlen_softcap() -> Result<()> { + let device = Device::new_cuda(0)?; + let q = Tensor::arange(0u32, 48, &device)? + .to_dtype(DType::F16)? 
+ .reshape((3, 2, 8))?; + let k = (&q / 40.)?; + let v = (&q / 50.)?; + let q = (&q / 30.)?; + + let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?; + let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?; + + let ys = { + let q = q.transpose(0, 1)?; + let k = k.transpose(0, 1)?; + let v = v.transpose(0, 1)?; + mcandle_flash_attn::flash_attn_varlen_softcap( + &q, + &k, + &v, + &seqlens_q, + &seqlens_k, + 32, + 32, + 0.5, + Some(30.), + false, + )? + .transpose(0, 1)? + }; + let ys = ys.to_dtype(DType::F32)?; + + assert_eq!(ys.dims(), &[3, 2, 8]); + assert_eq!( + to_vec3_round(ys, 4)?, + &[ + [ + [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238], + [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322] + ], + [ + [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605], + [0.428, 0.448, 0.468, 0.488, 0.5078, 0.5278, 0.5479, 0.5679] + ], + [ + [0.7549, 0.7749, 0.7949, 0.8149, 0.835, 0.855, 0.875, 0.895], + [0.7607, 0.7808, 0.8008, 0.8208, 0.8408, 0.8608, 0.8809, 0.9009] + ] + ] + ); + Ok(()) +} diff --git a/mcandle/mcandle-kernels/Cargo.toml b/mcandle/mcandle-kernels/Cargo.toml new file mode 100644 index 000000000..1ff29662d --- /dev/null +++ b/mcandle/mcandle-kernels/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "mcandle-kernels" +version = "0.3.2" +edition = "2021" + +description = "CUDA kernels for Candle" +repository = "https://github.com/huggingface/mcandle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" + +[dependencies] + +[build-dependencies] +bindgen_cuda = "0.1.1" diff --git a/mcandle/mcandle-kernels/README.md b/mcandle/mcandle-kernels/README.md new file mode 100644 index 000000000..7056b203a --- /dev/null +++ b/mcandle/mcandle-kernels/README.md @@ -0,0 +1,4 @@ +# mcandle-kernels + +This crate contains CUDA kernels used from mcandle. Some of these implementations +come from the [dfdx crate](https://github.com/coreylowman/dfdx). diff --git a/mcandle/mcandle-kernels/build.rs b/mcandle/mcandle-kernels/build.rs new file mode 100644 index 000000000..c28abd979 --- /dev/null +++ b/mcandle/mcandle-kernels/build.rs @@ -0,0 +1,11 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-changed=src/compatibility.cuh"); + println!("cargo:rerun-if-changed=src/cuda_utils.cuh"); + println!("cargo:rerun-if-changed=src/binary_op_macros.cuh"); + + let builder = bindgen_cuda::Builder::default(); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write("src/lib.rs").unwrap(); +} diff --git a/mcandle/mcandle-kernels/src/affine.cu b/mcandle/mcandle-kernels/src/affine.cu new file mode 100644 index 000000000..ef75dffd3 --- /dev/null +++ b/mcandle/mcandle-kernels/src/affine.cu @@ -0,0 +1,49 @@ +#include "cuda_utils.cuh" +#include + +#define AFFINE_OP(TYPENAME, FN_NAME, AFFINE) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const TYPENAME mul, \ + const TYPENAME add \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + TYPENAME x = inp ? 
inp[i] : out[i]; \ + out[i] = AFFINE; \ + } \ + } \ + else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + TYPENAME x = inp ? inp[strided_i] : out[i]; \ + out[i] = AFFINE; \ + } \ + } \ +} \ + +#if __CUDA_ARCH__ >= 800 +AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add))) +#endif + +#if __CUDA_ARCH__ >= 530 +AFFINE_OP(__half, affine_f16, x * mul + add) +#endif + +AFFINE_OP(float, affine_f32, x * mul + add) +AFFINE_OP(double, affine_f64, x * mul + add) +AFFINE_OP(uint8_t, affine_u8, x * mul + add) +AFFINE_OP(uint32_t, affine_u32, x * mul + add) +AFFINE_OP(int16_t, affine_i16, x * mul + add) +AFFINE_OP(int32_t, affine_i32, x * mul + add) +AFFINE_OP(int64_t, affine_i64, x * mul + add) diff --git a/mcandle/mcandle-kernels/src/binary.cu b/mcandle/mcandle-kernels/src/binary.cu new file mode 100644 index 000000000..7bda3e463 --- /dev/null +++ b/mcandle/mcandle-kernels/src/binary.cu @@ -0,0 +1,138 @@ +#include "binary_op_macros.cuh" +#include + +#if __CUDA_ARCH__ >= 800 +BINARY_OP(__nv_bfloat16, badd_bf16, x + y) +BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y) +BINARY_OP(__nv_bfloat16, bmul_bf16, x * y) +BINARY_OP(__nv_bfloat16, bsub_bf16, x - y) +BINARY_OP(__nv_bfloat16, bmaximum_bf16, maxg(x, y)) +BINARY_OP(__nv_bfloat16, bminimum_bf16, ming(x, y)) +BINARY_OP_OUT(__nv_bfloat16, uint8_t, eq_bf16, x == y) +BINARY_OP_OUT(__nv_bfloat16, uint8_t, ne_bf16, x != y) +BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y) +BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y) +BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y) +BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +BINARY_OP(__nv_fp8_e4m3, badd_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) + F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bdiv_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) / F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmul_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bsub_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) - F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmaximum_f8_e4m3, maxg(x, y)) +BINARY_OP(__nv_fp8_e4m3, bminimum_f8_e4m3, ming(x, y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, eq_f8_e4m3, F8E4M3_TO_FLOAT(x) == F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ne_f8_e4m3, F8E4M3_TO_FLOAT(x) != F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y)) +#endif + +#if __CUDA_ARCH__ >= 530 +BINARY_OP(__half, badd_f16, x + y) +BINARY_OP(__half, bdiv_f16, x / y) +BINARY_OP(__half, bmul_f16, x * y) +BINARY_OP(__half, bsub_f16, x - y) +BINARY_OP(__half, bmaximum_f16, maxg(x, y)) +BINARY_OP(__half, bminimum_f16, ming(x, y)) +BINARY_OP_OUT(__half, uint8_t, eq_f16, x == y) +BINARY_OP_OUT(__half, uint8_t, ne_f16, x != y) +BINARY_OP_OUT(__half, uint8_t, lt_f16, x < y) +BINARY_OP_OUT(__half, uint8_t, le_f16, x <= y) 
+BINARY_OP_OUT(__half, uint8_t, gt_f16, x > y) +BINARY_OP_OUT(__half, uint8_t, ge_f16, x >= y) +#endif + +BINARY_OP(float, badd_f32, x + y) +BINARY_OP(double, badd_f64, x + y); +BINARY_OP(uint8_t, badd_u8, x + y); +BINARY_OP(uint32_t, badd_u32, x + y); +BINARY_OP(int16_t, badd_i16, x + y); +BINARY_OP(int32_t, badd_i32, x + y); +BINARY_OP(int64_t, badd_i64, x + y); +BINARY_OP(float, bdiv_f32, x / y) +BINARY_OP(double, bdiv_f64, x / y); +BINARY_OP(uint8_t, bdiv_u8, x / y); +BINARY_OP(uint32_t, bdiv_u32, x / y); +BINARY_OP(int16_t, bdiv_i16, x / y); +BINARY_OP(int32_t, bdiv_i32, x / y); +BINARY_OP(int64_t, bdiv_i64, x / y); +BINARY_OP(float, bmul_f32, x * y) +BINARY_OP(double, bmul_f64, x * y); +BINARY_OP(uint8_t, bmul_u8, x * y); +BINARY_OP(uint32_t, bmul_u32, x * y); +BINARY_OP(int16_t, bmul_i16, x * y); +BINARY_OP(int32_t, bmul_i32, x * y); +BINARY_OP(int64_t, bmul_i64, x * y); +BINARY_OP(float, bsub_f32, x - y) +BINARY_OP(double, bsub_f64, x - y); +BINARY_OP(uint8_t, bsub_u8, x - y); +BINARY_OP(uint32_t, bsub_u32, x - y); +BINARY_OP(int16_t, bsub_i16, x - y); +BINARY_OP(int32_t, bsub_i32, x - y); +BINARY_OP(int64_t, bsub_i64, x - y); +BINARY_OP(float, bminimum_f32, ming(x, y)); +BINARY_OP(double, bminimum_f64, ming(x, y)); +BINARY_OP(uint8_t, bminimum_u8, ming(x, y)); +BINARY_OP(uint32_t, bminimum_u32, ming(x, y)); +BINARY_OP(int16_t, bminimum_i16, ming(x, y)); +BINARY_OP(int32_t, bminimum_i32, ming(x, y)); +BINARY_OP(int64_t, bminimum_i64, ming(x, y)); +BINARY_OP(float, bmaximum_f32, maxg(x, y)); +BINARY_OP(double, bmaximum_f64, maxg(x, y)); +BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y)); +BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y)); +BINARY_OP(int16_t, bmaximum_i16, maxg(x, y)); +BINARY_OP(int32_t, bmaximum_i32, maxg(x, y)); +BINARY_OP(int64_t, bmaximum_i64, maxg(x, y)); + +BINARY_OP_OUT(float, uint8_t, eq_f32, x == y) +BINARY_OP_OUT(double, uint8_t, eq_f64, x == y) +BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y) +BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y) +BINARY_OP_OUT(int16_t, uint8_t, eq_i16, x == y) +BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y) +BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y) + +BINARY_OP_OUT(float, uint8_t, ne_f32, x != y) +BINARY_OP_OUT(double, uint8_t, ne_f64, x != y) +BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y) +BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y) +BINARY_OP_OUT(int16_t, uint8_t, ne_i16, x != y) +BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y) +BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y) + +BINARY_OP_OUT(float, uint8_t, lt_f32, x < y) +BINARY_OP_OUT(double, uint8_t, lt_f64, x < y) +BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y) +BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y) +BINARY_OP_OUT(int16_t, uint8_t, lt_i16, x < y) +BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y) +BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y) + +BINARY_OP_OUT(float, uint8_t, le_f32, x <= y) +BINARY_OP_OUT(double, uint8_t, le_f64, x <= y) +BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y) +BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y) +BINARY_OP_OUT(int16_t, uint8_t, le_i16, x <= y) +BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y) +BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y) + +BINARY_OP_OUT(float, uint8_t, gt_f32, x > y) +BINARY_OP_OUT(double, uint8_t, gt_f64, x > y) +BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y) +BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y) +BINARY_OP_OUT(int16_t, uint8_t, gt_i16, x > y) +BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y) +BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y) + +BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y) 
+BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y)
+BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y)
+BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y)
+BINARY_OP_OUT(int16_t, uint8_t, ge_i16, x >= y)
+BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y)
+BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y)
diff --git a/mcandle/mcandle-kernels/src/binary_op_macros.cuh b/mcandle/mcandle-kernels/src/binary_op_macros.cuh
new file mode 100644
index 000000000..9cb00874f
--- /dev/null
+++ b/mcandle/mcandle-kernels/src/binary_op_macros.cuh
@@ -0,0 +1,69 @@
+#include "cuda_utils.cuh"
+
+#define BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *dims_and_strides, \
+    const TYPENAME *lhs, \
+    const TYPENAME *rhs, \
+    OUT_TYPENAME *out \
+) { \
+    const size_t *dims = dims_and_strides; \
+    const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
+    const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
+    bool lhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, lhs_strides); \
+    bool rhs_cont = dims_and_strides == nullptr || is_contiguous(num_dims, dims, rhs_strides); \
+    if (lhs_cont && rhs_cont) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            TYPENAME x = lhs[i]; \
+            TYPENAME y = rhs[i]; \
+            out[i] = FUNC; \
+        } \
+    } else if (lhs_cont) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned int tmp_i = i; \
+            unsigned int rhs_i = 0; \
+            for (int d = num_dims - 1; d >= 0; d--) { \
+                unsigned int i_dim = tmp_i % dims[d]; \
+                rhs_i += i_dim * rhs_strides[d]; \
+                tmp_i /= dims[d]; \
+            } \
+            TYPENAME x = lhs[i]; \
+            TYPENAME y = rhs[rhs_i]; \
+            out[i] = FUNC; \
+        } \
+    } else if (rhs_cont) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned int tmp_i = i; \
+            unsigned int lhs_i = 0; \
+            for (int d = num_dims - 1; d >= 0; d--) { \
+                unsigned int i_dim = tmp_i % dims[d]; \
+                lhs_i += i_dim * lhs_strides[d]; \
+                tmp_i /= dims[d]; \
+            } \
+            TYPENAME x = lhs[lhs_i]; \
+            TYPENAME y = rhs[i]; \
+            out[i] = FUNC; \
+        } \
+    } else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned int tmp_i = i; \
+            unsigned int lhs_i = 0; \
+            unsigned int rhs_i = 0; \
+            for (int d = num_dims - 1; d >= 0; d--) { \
+                unsigned int i_dim = tmp_i % dims[d]; \
+                lhs_i += i_dim * lhs_strides[d]; \
+                rhs_i += i_dim * rhs_strides[d]; \
+                tmp_i /= dims[d]; \
+            } \
+            TYPENAME x = lhs[lhs_i]; \
+            TYPENAME y = rhs[rhs_i]; \
+            out[i] = FUNC; \
+        } \
+    } \
+} \
+
+
+#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \
+    BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)
diff --git a/mcandle/mcandle-kernels/src/cast.cu b/mcandle/mcandle-kernels/src/cast.cu
new file mode 100644
index 000000000..207425b30
--- /dev/null
+++ b/mcandle/mcandle-kernels/src/cast.cu
@@ -0,0 +1,259 @@
+#include "cuda_utils.cuh"
+#include <stdint.h>
+
+template <typename S, typename T>
+__device__ void cast_(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const S *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = inp[i];
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = inp[strided_i];
+        }
+    }
+}
+
+#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))
+
+template <typename T>
+__device__ void cast_fp8_(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const __nv_fp8_e4m3 *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = F8E4M3_TO_FLOAT(inp[i]);
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = F8E4M3_TO_FLOAT(inp[strided_i]);
+        }
+    }
+}
+template <typename S>
+__device__ void cast_fp8_into_(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const S *inp,
+    __nv_fp8_e4m3 *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = __nv_fp8_e4m3((float)inp[i]);
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = __nv_fp8_e4m3((float)inp[strided_i]);
+        }
+    }
+}
+
+template <typename S, typename T, typename I>
+__device__ void cast_through(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const S *inp,
+    T *out
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (info == nullptr || is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            out[i] = static_cast<T>(static_cast<I>(inp[i]));
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            out[i] = static_cast<T>(static_cast<I>(inp[strided_i]));
+        }
+    }
+}
+
+
+#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const SRC_TYPENAME *inp, \
+    DST_TYPENAME *out \
+) { \
+    cast_<SRC_TYPENAME, DST_TYPENAME>(numel, num_dims, info, inp, out); \
+} \
+
+
+#define CAST_OP_FP8(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const SRC_TYPENAME *inp, \
+    DST_TYPENAME *out \
+) { \
+    cast_fp8_<DST_TYPENAME>(numel, num_dims, info, inp, out); \
+} \
+
+
+#define CAST_OP_FP8_INTO(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const SRC_TYPENAME *inp, \
+    DST_TYPENAME *out \
+) { \
+    cast_fp8_into_<SRC_TYPENAME>(numel, num_dims, info, inp, out); \
+} \
+
+#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const SRC_TYPENAME *inp, \
+    DST_TYPENAME *out \
+) { \
+    cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
+} \
+
+#if __CUDA_ARCH__ >= 800
+CAST_OP(__nv_bfloat16, 
__nv_bfloat16, cast_bf16_bf16) +CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3) + +CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32) +CAST_OP(__nv_bfloat16, float, cast_bf16_f32) +CAST_OP(__nv_bfloat16, double, cast_bf16_f64) +CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16) +CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16) +CAST_OP(float, __nv_bfloat16, cast_f32_bf16) +CAST_OP(double, __nv_bfloat16, cast_f64_bf16) +CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) +CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) +CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) + +CAST_OP_FP8(__nv_fp8_e4m3, float, cast_f8_e4m3_f32) +CAST_OP_FP8_INTO(float, __nv_fp8_e4m3, cast_f32_f8_e4m3) +CAST_OP_FP8(__nv_fp8_e4m3, uint8_t, cast_f8_e4m3_u8) +CAST_OP_FP8(__nv_fp8_e4m3, __half, cast_f8_e4m3_f16) +CAST_OP_FP8(__nv_fp8_e4m3, double, cast_f8_e4m3_f64) +CAST_OP_FP8_INTO(__half, __nv_fp8_e4m3, cast_f16_f8_e4m3) +CAST_OP_FP8_INTO(double, __nv_fp8_e4m3, cast_f64_f8_e4m3) +CAST_OP_FP8_INTO(uint8_t, __nv_fp8_e4m3, cast_u8_f8_e4m3) +CAST_OP_FP8_INTO(int32_t, __nv_fp8_e4m3, cast_i32_f8_e4m3) +CAST_OP_FP8(__nv_fp8_e4m3, int32_t, cast_f8_e4m3_i32) +CAST_OP_FP8(__nv_fp8_e4m3, __nv_bfloat16, cast_f8_e4m3_bf16) +CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3) +#else +#include +#if CUDA_VERSION >= 11000 +CAST_OP(__nv_bfloat16, float, cast_bf16_f32) +CAST_OP(float, __nv_bfloat16, cast_f32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) +CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) +CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) +CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) +CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) +#endif +#endif + +#if __CUDA_ARCH__ >= 530 +CAST_OP(__half, __half, cast_f16_f16) + +CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8) +CAST_OP(__half, uint32_t, cast_f16_u32) +CAST_OP(__half, float, cast_f16_f32) +CAST_OP(__half, double, cast_f16_f64) +CAST_OP(uint8_t, __half, cast_u8_f16 ) +CAST_OP(uint32_t, __half, cast_u32_f16) +CAST_OP(float, __half, cast_f32_f16) +CAST_OP(double, __half, cast_f64_f16) +CAST_OP(int32_t, __half, cast_i32_f16 ) +CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32) +#endif + +CAST_OP(uint32_t, uint32_t, cast_u32_u32) +CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) +CAST_OP(uint32_t, int64_t, cast_u32_i64 ) +CAST_OP(uint32_t, int32_t, cast_u32_i32 ) +CAST_OP(uint32_t, int16_t, cast_u32_i16 ) +CAST_OP(uint32_t, float, cast_u32_f32) +CAST_OP(uint32_t, double, cast_u32_f64) + +CAST_OP(uint8_t, uint32_t, cast_u8_u32) +CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int16_t, cast_u8_i16 ) +CAST_OP(uint8_t, int32_t, cast_u8_i32 ) +CAST_OP(uint8_t, int64_t, cast_u8_i64 ) +CAST_OP(uint8_t, float, cast_u8_f32) +CAST_OP(uint8_t, double, cast_u8_f64) + +CAST_OP(int64_t, uint32_t, cast_i64_u32) +CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int16_t, cast_i64_i16 ) +CAST_OP(int64_t, int32_t, cast_i64_i32 ) +CAST_OP(int64_t, int64_t, cast_i64_i64 ) +CAST_OP(int64_t, float, cast_i64_f32) +CAST_OP(int64_t, double, cast_i64_f64) + +CAST_OP(int32_t, 
uint32_t, cast_i32_u32) +CAST_OP(int32_t, uint8_t, cast_i32_u8 ) +CAST_OP(int32_t, int64_t, cast_i32_i64 ) +CAST_OP(int32_t, int32_t, cast_i32_i32 ) +CAST_OP(int32_t, int16_t, cast_i32_i16 ) +CAST_OP(int32_t, float, cast_i32_f32) +CAST_OP(int32_t, double, cast_i32_f64) + +CAST_OP(int16_t, uint32_t, cast_i16_u32) +CAST_OP(int16_t, uint8_t, cast_i16_u8 ) +CAST_OP(int16_t, int64_t, cast_i16_i64 ) +CAST_OP(int16_t, int32_t, cast_i16_i32 ) +CAST_OP(int16_t, int16_t, cast_i16_i16 ) +CAST_OP(int16_t, float, cast_i16_f32) +CAST_OP(int16_t, double, cast_i16_f64) + +CAST_OP(float, uint8_t, cast_f32_u8 ) +CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int16_t, cast_f32_i16 ) +CAST_OP(float, int32_t, cast_f32_i32 ) +CAST_OP(float, int64_t, cast_f32_i64 ) +CAST_OP(float, float, cast_f32_f32) +CAST_OP(float, double, cast_f32_f64) + +CAST_OP(double, uint8_t, cast_f64_u8 ) +CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int16_t, cast_f64_i16 ) +CAST_OP(double, int32_t, cast_f64_i32 ) +CAST_OP(double, int64_t, cast_f64_i64 ) +CAST_OP(double, float, cast_f64_f32) +CAST_OP(double, double, cast_f64_f64) diff --git a/mcandle/mcandle-kernels/src/compatibility.cuh b/mcandle/mcandle-kernels/src/compatibility.cuh new file mode 100644 index 000000000..1e4cf215c --- /dev/null +++ b/mcandle/mcandle-kernels/src/compatibility.cuh @@ -0,0 +1,164 @@ +#include "cuda_fp16.h" +#include "cuda_bf16.h" +#include "cuda_fp8.h" + +// Table showing which features are supported on which compute capability +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications + +// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough + +#if (__CUDACC_VER_MAJOR__ < 12 || __CUDACC_VER_MINOR__ < 2) && __CUDA_ARCH__ < 800 +__device__ __forceinline__ __half __hmax_nan(__half a, __half b) { + return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b)); +} +__device__ __forceinline__ __half __hmin_nan(__half a, __half b) { + return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b)); +} +#endif + +#if __CUDA_ARCH__ < 600 +// Copied from https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions +__device__ double atomicAdd(double* address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + + __longlong_as_double(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + + +#if __CUDA_ARCH__ < 700 +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomicadd +// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher. +// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119 +__device__ __half atomicAdd(__half *address, __half val) { + // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + // unsigned int old = *address_as_ui; + // unsigned int assumed; + // bool unaligned = (size_t) address & 2; + // do { + // assumed = old; + // unsigned int hsum; + // hsum = unaligned ? (old >> 16) : (old & 0xffff); + // hsum = __half_as_ushort(__ushort_as_half(hsum) + val); + // old = atomicCAS(address_as_ui, assumed, + // unaligned ? 
(old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum + // ); + + // } while (assumed != old); + // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff)); +} +#endif + + +__device__ __forceinline__ __half atomicMaxf(__half* address, __half val) { +#if __CUDA_ARCH__ < 700 + // On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery. + // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119 + unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + bool unaligned = (size_t) address & 2; + do { + assumed = old; + unsigned int hmax; + hmax = unaligned ? (old >> 16) : (old & 0xffff); + hmax = __half_as_ushort(__hmax_nan(val, __ushort_as_half(hmax))); + old = atomicCAS(address_as_ui, assumed, + unaligned ? (old & 0xffff) | (hmax << 16) : (old & 0xffff0000) | hmax + ); + + } while (assumed != old); + return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff)); +#else + // Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions + unsigned short int* casted_address = (unsigned short int*)address; + unsigned short int old = *casted_address; + unsigned short int assumed; + do { + assumed = old; + old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmax_nan(val, __ushort_as_half(assumed)))); + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + return __ushort_as_half(old); +#endif +} + +// atomicMax is not implemented for floats, +// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +__device__ __forceinline__ float atomicMaxf(float * addr, float value) { + if (signbit(value)) { + return __uint_as_float(atomicMin((unsigned int *)addr, __float_as_uint(value))); + } else { + return __int_as_float(atomicMax((int *)addr, __float_as_int(value))); + } +} + +__device__ __forceinline__ double atomicMaxf(double * addr, double value) { + if (signbit(value)) { + return __longlong_as_double(atomicMin((unsigned long long int *)addr, __double_as_longlong(value))); + } else { + return __longlong_as_double(atomicMax((long long int *)addr, __double_as_longlong(value))); + } +} + + +__device__ __forceinline__ __half atomicMinf(__half* address, __half val) { +#if __CUDA_ARCH__ < 700 + // On older GPUs we do not have access to atomicCAS for shorts, so we have to do some trickery. + // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119 + unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + bool unaligned = (size_t) address & 2; + do { + assumed = old; + unsigned int hmin; + hmin = unaligned ? (old >> 16) : (old & 0xffff); + hmin = __half_as_ushort(__hmin_nan(val, __ushort_as_half(hmin))); + old = atomicCAS(address_as_ui, assumed, + unaligned ? (old & 0xffff) | (hmin << 16) : (old & 0xffff0000) | hmin + ); + + } while (assumed != old); + return __ushort_as_half(unaligned ? 
(old >> 16) : (old & 0xffff)); +#else + // Based on https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions + unsigned short int* casted_address = (unsigned short int*)address; + unsigned short int old = *casted_address; + unsigned short int assumed; + do { + assumed = old; + old = atomicCAS(casted_address, assumed, __half_as_ushort(__hmin_nan(val, __ushort_as_half(assumed)))); + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); + return __ushort_as_half(old); +#endif +} + +// atomicMin is not implemented for floats, +// solution copied https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +__device__ __forceinline__ float atomicMinf(float * addr, float value) { + if (signbit(value)) { + return __uint_as_float(atomicMax((unsigned int *)addr, __float_as_uint(value))); + } else { + return __int_as_float(atomicMin((int *)addr, __float_as_int(value))); + } +} + +__device__ __forceinline__ double atomicMinf(double * addr, double value) { + if (signbit(value)) { + return __longlong_as_double(atomicMax((unsigned long long int *)addr, __double_as_longlong(value))); + } else { + return __longlong_as_double(atomicMin((long long int *)addr, __double_as_longlong(value))); + } +} diff --git a/mcandle/mcandle-kernels/src/conv.cu b/mcandle/mcandle-kernels/src/conv.cu new file mode 100644 index 000000000..6ca6fd7c2 --- /dev/null +++ b/mcandle/mcandle-kernels/src/conv.cu @@ -0,0 +1,780 @@ +#include "cuda_utils.cuh" +#include + +// Naive implementation of conv1d. +template +__device__ void conv1d( + const size_t src_numel, + const size_t l_out, + const size_t stride, + const size_t padding, + const size_t dilation, + const size_t *info, + const T *src, + const T *kernel, + T *dst +) { + // src: (b_size, c_in, l_in) + // k: (c_out, c_in, k_size) + const size_t *src_dims = info; + const size_t *src_s = info + 3; + const size_t *k_dims = info + 6; + const size_t *k_s = info + 9; + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + const size_t k_size = k_dims[2]; + const size_t c_out = k_dims[0]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + if (dst_i >= src_dims[0] * c_out * l_out) { + return; + } + + // TODO + const size_t b_idx = dst_i / (l_out * c_out); + const size_t dst_c_idx = (dst_i / l_out) % c_out; + const size_t dst_l = dst_i % l_out; + + const size_t src_idx0 = b_idx * src_s[0]; + A d = 0; + for (size_t offset = 0; offset < k_size; ++offset) { + size_t src_l = (stride * dst_l + offset) * dilation; + if (src_l < padding || src_l >= padding + l_in) { + continue; + } + src_l -= padding; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2]; + const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2]; + d += static_cast(src[src_idx]) * static_cast(kernel[k_idx]); + } + } + dst[dst_i] = static_cast(d); +} + +template +__device__ void im2col1d( + const size_t dst_numel, + const size_t l_out, + const size_t l_k, + const size_t stride, + const size_t padding, + const size_t dilation, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // dst: (b_size, l_out, c_in, l_k) + // src: (b_size, c_in, l_in) + if (dst_i >= dst_numel) { + return; + } + const size_t *src_dims = info; + const size_t *src_s = info + 3; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + + const 
size_t dst_s2 = l_k; + const size_t dst_s1 = c_in * dst_s2; + const size_t dst_s0 = l_out * dst_s1; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t l_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= l_idx * dst_s1; + const size_t c_idx = tmp_dst_i / dst_s2; + tmp_dst_i -= c_idx * dst_s2; + const size_t l_k_idx = tmp_dst_i; + size_t src_l_idx = l_idx * stride + l_k_idx * dilation; + if (src_l_idx < padding || src_l_idx >= l_in + padding) { + dst[dst_i] = static_cast(0); + } + else { + src_l_idx -= padding; + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2]; + dst[dst_i] = src[src_i]; + } +} + +template +__device__ void col2im1d( + const size_t dst_el, + const size_t l_out, + const size_t l_in, + const size_t c_out, + const size_t k_size, + const size_t stride, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, l_in, c_out, l_k) + // dst: (b_size, c_out, l_out) + if (dst_i >= dst_el) { + return; + } + + const size_t dst_s0 = c_out * l_out; + const size_t dst_s1 = l_out; + const size_t src_s0 = c_out * k_size * l_in; + const size_t src_s1 = c_out * k_size; + const size_t src_s2 = k_size; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t c_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= c_idx * dst_s1; + const int l_out_idx = tmp_dst_i; + + dst[dst_i] = static_cast(0); + + int l_in_idx = l_out_idx / stride; + int k0 = l_out_idx - l_in_idx * stride; + // l_out_idx = l_in_idx * stride + k0 + for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) { + if (l_in_idx < l_in) { + const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0; + dst[dst_i] += src[src_i]; + } + } +} + +template +__device__ void im2col( + const size_t dst_numel, + const size_t h_out, + const size_t w_out, + const size_t h_k, + const size_t w_k, + const size_t stride, + const size_t padding, + const size_t dilation, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // dst: (b_size, h_out, w_out, c_in, h_k, w_k) + // src: (b_size, c_in, h_in, w_in) + if (dst_i >= dst_numel) { + return; + } + const size_t *src_dims = info; + const size_t *src_s = info + 4; + const size_t c_in = src_dims[1]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; + + const size_t dst_s4 = w_k; + const size_t dst_s3 = h_k * dst_s4; + const size_t dst_s2 = c_in * dst_s3; + const size_t dst_s1 = w_out * dst_s2; + const size_t dst_s0 = h_out * dst_s1; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t h_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= h_idx * dst_s1; + const size_t w_idx = tmp_dst_i / dst_s2; + tmp_dst_i -= w_idx * dst_s2; + const size_t c_idx = tmp_dst_i / dst_s3; + tmp_dst_i -= c_idx * dst_s3; + const size_t h_k_idx = tmp_dst_i / dst_s4; + tmp_dst_i -= h_k_idx * dst_s4; + const size_t w_k_idx = tmp_dst_i; + size_t src_h_idx = h_idx * stride + h_k_idx * dilation; + size_t src_w_idx = w_idx * stride + w_k_idx * dilation; + if (src_h_idx < padding || src_h_idx >= h_in + padding) { + dst[dst_i] = static_cast(0); + } + else if (src_w_idx < padding || src_w_idx >= w_in + padding) { + dst[dst_i] = static_cast(0); + } + else { + src_h_idx -= padding; + src_w_idx -= padding; + const size_t src_i = + b_idx * src_s[0] + + c_idx * src_s[1] + + src_h_idx * src_s[2] + 
+ src_w_idx * src_s[3]; + dst[dst_i] = src[src_i]; + } +} + +// Naive implementation of conv2d. +template +__device__ void conv2d( + const size_t src_numel, + const size_t w_out, + const size_t h_out, + const size_t stride, + const size_t padding, + const size_t dilation, + const size_t *info, + const T *src, + const T *kernel, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, h_in, w_in) + // k: (c_out, c_in, h_k, w_k) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + const size_t *k_dims = info + 8; + const size_t *k_s = info + 12; + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; + const size_t c_out = k_dims[0]; + const size_t c_in = src_dims[1]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; + if (dst_i >= src_dims[0] * c_out * w_out * h_out) { + return; + } + + // TODO + const size_t b_idx = dst_i / (w_out * h_out * c_out); + const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out; + // NCHW layout. + const size_t dst_h = (dst_i / w_out) % h_out; + const size_t dst_w = dst_i % w_out; + + const size_t src_idx0 = b_idx * src_s[0]; + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = stride * dst_w + w_offset * dilation; + if (src_w < padding || src_w >= w_in + padding) { + continue; + } + src_w -= padding; + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = stride * dst_h + h_offset * dilation; + if (src_h < padding || src_h >= h_in + padding) { + continue; + } + src_h -= padding; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_h * src_s[2] + src_w * src_s[3]; + const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + h_offset * k_s[2] + w_offset * k_s[3]; + d += static_cast(src[src_idx]) * static_cast(kernel[k_idx]); + } + } + } + dst[dst_i] = static_cast(d); +} + +// Naive implementation of conv_transpose1d. +template +__device__ void conv_transpose1d( + const size_t src_numel, + const size_t l_out, + const size_t stride, + const size_t padding, + const size_t out_padding, + const size_t dilation, + const size_t *info, + const T *src, + const T *kernel, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, l_in) + // k: (c_in, c_out, l_k) + const size_t *src_dims = info; + const size_t *src_s = info + 3; + const size_t *k_dims = info + 6; + const size_t *k_s = info + 9; + const size_t l_k = k_dims[2]; + const size_t c_out = k_dims[1]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + if (dst_i >= src_dims[0] * c_out * l_out) { + return; + } + + // TODO + const size_t b_idx = dst_i / (l_out * c_out); + const size_t dst_c_idx = (dst_i / l_out) % c_out; + // NCL layout. 
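+  // dst_i enumerates the (b, c_out, out_x) output elements; the loop below
+  // gathers every (k_x, c_in) tap for which the forward relation
+  // out_x = inp_x * stride + k_x * dilation - padding yields an integral inp_x.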
+ const size_t out_x = dst_i % l_out; + + const size_t src_idx0 = b_idx * src_s[0]; + A d = 0; + for (int k_x = 0; k_x < (int)l_k; ++k_x) { + // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding; + int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + int inp_x = inp_x_stride / stride; + if (inp_x >= l_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2]; + const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2]; + d += static_cast(src[src_idx]) * static_cast(kernel[k_idx]); + } + } + dst[dst_i] = static_cast(d); +} + +// Naive implementation of conv_transpose2d. +template +__device__ void conv_transpose2d( + const size_t src_numel, + const size_t w_out, + const size_t h_out, + const size_t stride, + const size_t padding, + const size_t out_padding, + const size_t dilation, + const size_t *info, + const T *src, + const T *kernel, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, h_in, w_in) + // k: (c_in, c_out, h_k, w_k) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + const size_t *k_dims = info + 8; + const size_t *k_s = info + 12; + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; + const size_t c_out = k_dims[1]; + const size_t c_in = src_dims[1]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; + if (dst_i >= src_dims[0] * c_out * w_out * h_out) { + return; + } + + // TODO + const size_t b_idx = dst_i / (w_out * h_out * c_out); + const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out; + // NCHW layout. + const size_t out_y = (dst_i / w_out) % h_out; + const size_t out_x = dst_i % w_out; + + const size_t src_idx0 = b_idx * src_s[0]; + A d = 0; + for (int k_x = 0; k_x < (int)w_k; ++k_x) { + // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding; + int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + int inp_x = inp_x_stride / stride; + if (inp_x >= w_in) continue; + for (int k_y = 0; k_y < (int)h_k; ++k_y) { + int inp_y_stride = (int)(out_y + padding) - k_y * dilation; + if (inp_y_stride < 0 || inp_y_stride % stride) { + continue; + } + int inp_y = inp_y_stride / stride; + if (inp_y >= h_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3]; + const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3]; + d += static_cast(src[src_idx]) * static_cast(kernel[k_idx]); + } + } + } + dst[dst_i] = static_cast(d); +} + +template +__device__ void avg_pool2d( + const size_t src_numel, + const size_t w_k, + const size_t h_k, + const size_t w_stride, + const size_t h_stride, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. 
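+  // One thread per output element: dst_i splits into (b_idx, c_idx, dst_w, dst_h).
+  // Taps falling outside the input are skipped, but the divisor stays
+  // 1 / (w_k * h_k) (the scale factor below), so border windows still average
+  // over the full kernel area.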
+ const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + const size_t src_idx0 = b_idx * src_s[0]; + const float scale = 1.0 / (w_k * h_k); + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in) { + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + d += static_cast(src[src_idx]); + } + } + dst[dst_i] = static_cast(d * scale); +} + +template +__device__ void max_pool2d( + const size_t src_numel, + const size_t w_k, + const size_t h_k, + const size_t w_stride, + const size_t h_stride, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. + const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + const size_t src_idx0 = b_idx * src_s[0]; + T d = 0; + bool set = false; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in) { + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + if (set) { + d = maxg(d, src[src_idx]); + } + else { + d = src[src_idx]; + set = true; + } + } + } + dst[dst_i] = d; +} + +template +__device__ void upsample_nearest2d( + const size_t w_out, + const size_t h_out, + const double w_scale, + const double h_scale, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. 
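+  // Nearest-neighbour lookup: src_w = floor(dst_w * w_scale) and
+  // src_h = floor(dst_h * h_scale), clamped to the last valid row/column;
+  // the caller presumably passes w_scale = w_in / w_out and h_scale = h_in / h_out.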
+ const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + size_t src_w = static_cast(dst_w * w_scale); + size_t src_h = static_cast(dst_h * h_scale); + if (src_w >= w_in) { + src_w = w_in - 1; + } + if (src_h >= h_in) { + src_h = h_in - 1; + } + + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + dst[dst_i] = src[src_i]; +} + + +#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t num_dims, \ + const size_t stride, \ + const size_t padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + const TYPENAME *kernel, \ + TYPENAME *dst \ +) { \ + conv1d(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \ +} \ + +#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_out, \ + const size_t h_out, \ + const size_t stride, \ + const size_t padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + const TYPENAME *kernel, \ + TYPENAME *dst \ +) { \ + conv2d(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \ +} \ + +#define IM2COL1D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t dst_numel, \ + const size_t l_out, \ + const size_t l_k, \ + const size_t stride, \ + const size_t padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + im2col1d(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \ +} \ + +#define COL2IM1D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t dst_el, \ + const size_t l_out, \ + const size_t l_in, \ + const size_t c_out, \ + const size_t k_size, \ + const size_t stride, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + col2im1d(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \ +} \ + +#define IM2COL_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t dst_numel, \ + const size_t h_out, \ + const size_t w_out, \ + const size_t h_k, \ + const size_t w_k, \ + const size_t stride, \ + const size_t padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + im2col(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \ +} \ + +#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t l_out, \ + const size_t stride, \ + const size_t padding, \ + const size_t out_padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + const TYPENAME *kernel, \ + TYPENAME *dst \ +) { \ + conv_transpose1d(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \ +} \ + +#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_out, \ + const size_t h_out, \ + const size_t stride, \ + const size_t padding, \ + const size_t out_padding, \ + const size_t dilation, \ + const size_t *info, \ + const TYPENAME *src, \ + const TYPENAME *kernel, \ + TYPENAME *dst \ +) { \ + conv_transpose2d(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \ +} \ + +#define AVG_POOL2D_OP(TYPENAME, 
TYPEACC, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_k, \ + const size_t h_k, \ + const size_t w_stride, \ + const size_t h_stride, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + avg_pool2d(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ +} \ + +#define MAX_POOL2D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t src_numel, \ + const size_t w_k, \ + const size_t h_k, \ + const size_t w_stride, \ + const size_t h_stride, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + max_pool2d(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ +} \ + +#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t w_out, \ + const size_t h_out, \ + const double w_scale, \ + const double h_scale, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + upsample_nearest2d(w_out, h_out, w_scale, h_scale, info, src, dst); \ +} \ + +#if __CUDA_ARCH__ >= 800 +CONV1D_OP(__nv_bfloat16, float, conv1d_bf16) +CONV2D_OP(__nv_bfloat16, float, conv2d_bf16) +CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16) +CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16) +AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16) +MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) +UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) +IM2COL_OP(__nv_bfloat16, im2col_bf16) +IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) +COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16) + +// NOTE: No conv ops for f8 +// CONV1D_OP(__nv_bfloat16, float, conv1d_f8_e5m) +// CONV2D_OP(__nv_fp8_e4m3, float, conv2d_f8_e5m) +// CONVT1D_OP(__nv_fp8_e4m3, float, conv_transpose1d_f8_e5m) +// CONVT2D_OP(__nv_fp8_e4m3, float, conv_transpose2d_f8_e5m) +// AVG_POOL2D_OP(__nv_fp8_e4m3, float, avg_pool2d_f8_e5m) +// MAX_POOL2D_OP(__nv_fp8_e4m3, max_pool2d_f8_e5m) +// UPSAMPLE_NEAREST2D_OP(__nv_fp8_e4m3, upsample_nearest2d_f8_e5m) +// IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m) +// IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m) +// COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m) +#endif + +#if __CUDA_ARCH__ >= 530 +CONV1D_OP(__half, float, conv1d_f16) +CONV2D_OP(__half, float, conv2d_f16) +CONVT1D_OP(__half, float, conv_transpose1d_f16) +CONVT2D_OP(__half, float, conv_transpose2d_f16) +AVG_POOL2D_OP(__half, float, avg_pool2d_f16) +MAX_POOL2D_OP(__half, max_pool2d_f16) +UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16) +IM2COL_OP(__half, im2col_f16) +IM2COL1D_OP(__half, im2col1d_f16) +COL2IM1D_OP(__half, col2im1d_f16) +#endif + +CONV1D_OP(float, float, conv1d_f32) +CONV1D_OP(double, double, conv1d_f64) +CONV1D_OP(uint8_t, uint8_t, conv1d_u8) +CONV1D_OP(uint32_t, uint32_t, conv1d_u32) + +CONV2D_OP(float, float, conv2d_f32) +CONV2D_OP(double, double, conv2d_f64) +CONV2D_OP(uint8_t, uint8_t, conv2d_u8) +CONV2D_OP(uint32_t, uint32_t, conv2d_u32) + +CONVT1D_OP(float, float, conv_transpose1d_f32) +CONVT1D_OP(double, double, conv_transpose1d_f64) +CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8) +CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32) + +CONVT2D_OP(float, float, conv_transpose2d_f32) +CONVT2D_OP(double, double, conv_transpose2d_f64) +CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8) +CONVT2D_OP(uint32_t, uint32_t, conv_transpose2d_u32) + +AVG_POOL2D_OP(float, float, avg_pool2d_f32) +AVG_POOL2D_OP(double, double, avg_pool2d_f64) +AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8) +AVG_POOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32) + +MAX_POOL2D_OP(float, 
max_pool2d_f32) +MAX_POOL2D_OP(double, max_pool2d_f64) +MAX_POOL2D_OP(uint8_t, max_pool2d_u8) +MAX_POOL2D_OP(uint32_t, max_pool2d_u32) + +UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) +UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64) +UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) +UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) + +IM2COL_OP(float, im2col_f32) +IM2COL_OP(double, im2col_f64) +IM2COL_OP(uint8_t, im2col_u8) +IM2COL_OP(uint32_t, im2col_u32) + +IM2COL1D_OP(float, im2col1d_f32) +IM2COL1D_OP(double, im2col1d_f64) +IM2COL1D_OP(uint8_t, im2col1d_u8) +IM2COL1D_OP(uint32_t, im2col1d_u32) + +COL2IM1D_OP(float, col2im1d_f32) +COL2IM1D_OP(double, col2im1d_f64) +COL2IM1D_OP(uint8_t, col2im1d_u8) +COL2IM1D_OP(uint32_t, col2im1d_u32) diff --git a/mcandle/mcandle-kernels/src/cuda_utils.cuh b/mcandle/mcandle-kernels/src/cuda_utils.cuh new file mode 100644 index 000000000..f3af71b18 --- /dev/null +++ b/mcandle/mcandle-kernels/src/cuda_utils.cuh @@ -0,0 +1,257 @@ +#include "compatibility.cuh" +#include +#include + +// TODO: This is often used to check that the data is contiguous so that +// kernels can be easily mapped. However this only returns true for row +// major, if all the inputs are column major, we could apply the fast path +// too (but we wouldn't if some of them are row major and some column major). +__device__ bool is_contiguous( + const size_t num_dims, + const size_t *dims, + const size_t *strides +) { + size_t acc = 1; + for (unsigned int d = 0; d < num_dims; d++) { + unsigned int dim_idx = num_dims - 1 - d; + if (dims[dim_idx] > 1 && acc != strides[dim_idx]) { + return false; + } + acc *= dims[dim_idx]; + } + return true; +} + +__device__ unsigned int get_strided_index( + unsigned int idx, + const size_t num_dims, + const size_t *dims, + const size_t *strides +) { + unsigned int strided_i = 0; + for (unsigned int d = 0; d < num_dims; d++) { + unsigned int dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +__device__ unsigned int restrided( + const unsigned int strided_i, + const size_t num_dims, + const size_t *dims, + const size_t *strides, + const size_t *new_strides +) { + unsigned int idx = 0; + for (int d = 0; d < num_dims; d++) { + idx += (strides[d] == 0 ? 
0 : (strided_i / strides[d]) % dims[d]) * new_strides[d]; + } + return idx; +} + +// Sourced from https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 +// Input must be less than or equal to 2 ^ 16 +// used in reductions +__device__ __forceinline__ unsigned int next_power_of_two(unsigned int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v++; + return v; +} + +// Efficiently computes the sum of each chunk in "data" of size chunk_len, and +// stores the sums in out[i / chunk_len] +template +__device__ void chunk_sum( + const size_t chunk_len, + const T data, + T* out +) { + __shared__ T buf[1024]; + + // assumes that threads where i >= numel have already exited + unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int block_i = threadIdx.x; + + // Fall back to atomicAdd if chunk_len is small to reduce overhead + if (chunk_len <= 2) { + atomicAdd(out + i / chunk_len, data); + return; + } + buf[block_i] = data; + + unsigned int chunk_i = i % chunk_len; + unsigned int chunk_start = max((int)(block_i - chunk_i), 0); + unsigned int chunk_end = min((unsigned int)(block_i + chunk_len - chunk_i), blockDim.x); + + chunk_i = block_i - chunk_start; + + size_t max_chunk_len = min(chunk_end - chunk_start, blockDim.x); + size_t incr = next_power_of_two(max_chunk_len) >> 1; + + __syncthreads(); + + // Uses sequential addressing as discussed in + // https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf + for (; incr > 0; incr >>= 1) { + unsigned int block_i_2 = block_i + incr; + + if (block_i_2 < chunk_end && chunk_i < incr) { + // This is sound because __syncthreads and the conditions above + // ensure that no data races occur + buf[block_i] += buf[block_i_2]; + } + + __syncthreads(); + } + + if (block_i == chunk_start) { + atomicAdd(out + i / chunk_len, buf[block_i]); + } +} + +__device__ __forceinline__ int GetBlockNum(void) { + return (gridDim.x * gridDim.y * gridDim.z); +} + +__device__ __forceinline__ int GetBlockIdx(void) { + return (blockIdx.z * (gridDim.x * gridDim.y) + blockIdx.y * gridDim.x + + blockIdx.x); +} + +__device__ __forceinline__ int GetThreadNumEachBlock(void) { + return (blockDim.x * blockDim.y * blockDim.z); +} + +__device__ __forceinline__ int GetThreadNum(void) { + return GetBlockNum() * GetThreadNumEachBlock(); +} + +__device__ __forceinline__ int GetThreadIdxInBlock(void) { + return threadIdx.z * (blockDim.x * blockDim.y) + + threadIdx.y * blockDim.x + threadIdx.x; +} + +__device__ __forceinline__ int GetThreadIdx(void) { + int blockIdx = GetBlockIdx(); + int threadNumEachBlock = GetThreadNumEachBlock(); + + return blockIdx * threadNumEachBlock + GetThreadIdxInBlock(); +} + +__device__ __forceinline__ bool isnang(float a) { return isnan(a); } +__device__ __forceinline__ bool isnang(double a) { return isnan(a); } +__device__ __forceinline__ float recipg(float a) { return 1.0 / a; } +__device__ __forceinline__ double recipg(double a) { return 1.0 / a; } +__device__ __forceinline__ float cosg(float a) { return cosf(a); } +__device__ __forceinline__ double cosg(double a) { return cos(a); } +__device__ __forceinline__ float sing(float a) { return sinf(a); } +__device__ __forceinline__ double sing(double a) { return sin(a); } +__device__ __forceinline__ float sqrtg(float a) { return sqrtf(a); } +__device__ __forceinline__ double sqrtg(double a) { return sqrt(a); } +__device__ __forceinline__ float powg(float a, float b) { return powf(a, b); } +__device__ __forceinline__ double powg(double a, double b) { return 
pow(a, b); } +__device__ __forceinline__ float tanhg(float a) { return tanhf(a); } +__device__ __forceinline__ double tanhg(double a) { return tanh(a); } +__device__ __forceinline__ float erfg(float a) { return erff(a); } +__device__ __forceinline__ double erfg(double a) { return erf(a); } +__device__ __forceinline__ float ceilg(float a) { return ceilf(a); } +__device__ __forceinline__ double ceilg(double a) { return ceil(a); } +__device__ __forceinline__ float floorg(float a) { return floorf(a); } +__device__ __forceinline__ double floorg(double a) { return floor(a); } +__device__ __forceinline__ float roundg(float a) { return roundf(a); } +__device__ __forceinline__ double roundg(double a) { return round(a); } +__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); } +__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); } +__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); } +__device__ __forceinline__ double maxg(double a, double b) { return fmax(a, b); } +__device__ __forceinline__ float ming(float a, float b) { return fminf(a, b); } +__device__ __forceinline__ double ming(double a, double b) { return fmin(a, b); } +__device__ __forceinline__ float logg(float a) { return logf(a); } +__device__ __forceinline__ double logg(double a) { return log(a); } +__device__ __forceinline__ float expg(float a) { return expf(a); } +__device__ __forceinline__ double expg(double a) { return exp(a); } +__device__ __forceinline__ float absg(float a) { return fabsf(a); } +__device__ __forceinline__ double absg(double a) { return fabs(a); } +__device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); } +__device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); } + +__device__ __forceinline__ int16_t ming(int16_t a, int16_t b) { return min(a, b); } +__device__ __forceinline__ int16_t maxg(int16_t a, int16_t b) { return max(a, b); } +__device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); } +__device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); } +__device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); } +__device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); } +__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); } +__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); } +__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); } +__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); } +#if __CUDA_ARCH__ >= 530 +__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); } +__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); } +__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); } +__device__ __forceinline__ __half cosg(__half a) { return hcos(a); } +__device__ __forceinline__ __half sing(__half a) { return hsin(a); } +__device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return one / a; } +__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); } +__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); } +__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); } +__device__ __forceinline__ __half ceilg(__half a) { return 
__float2half(ceilf(__half2float(a))); } +__device__ __forceinline__ __half floorg(__half a) { return __float2half(floorf(__half2float(a))); } +__device__ __forceinline__ __half roundg(__half a) { return __float2half(roundf(__half2float(a))); } +__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); } +__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); } +__device__ __forceinline__ __half logg(__half a) { return hlog(a); } +__device__ __forceinline__ __half expg(__half a) { return hexp(a); } +__device__ __forceinline__ __half absg(__half a) { return __habs(a); } +__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); } +#endif + +#if __CUDA_ARCH__ >= 800 +__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); } +__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); } +__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); } +__device__ __forceinline__ __nv_bfloat16 cosg(__nv_bfloat16 a) { return hcos(a); } +__device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); } +__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; } +__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); } +__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 ceilg(__nv_bfloat16 a) { return __float2bfloat16(ceilf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 floorg(__nv_bfloat16 a) { return __float2bfloat16(floorf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 roundg(__nv_bfloat16 a) { return __float2bfloat16(roundf(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); } +__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); } +__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); } +__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); } +__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); } +__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); } + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +__device__ __forceinline__ __nv_fp8_e4m3 powg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(powf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ bool isnang(__nv_fp8_e4m3 a) { return isnan(F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 sqrtg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sqrtf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 cosg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(cosf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 sing(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sinf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 
recipg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(1. / F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 maxg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fmaxf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 tanhg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(tanhf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 erfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(erff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ceilg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(ceilf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 floorg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(floorf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 roundg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(roundf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 normcdfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(normcdff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ming(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fminf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 logg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(logf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(expf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } + + +#endif diff --git a/mcandle/mcandle-kernels/src/fill.cu b/mcandle/mcandle-kernels/src/fill.cu new file mode 100644 index 000000000..eeea8d4cd --- /dev/null +++ b/mcandle/mcandle-kernels/src/fill.cu @@ -0,0 +1,57 @@ +#include +#include "cuda_fp16.h" + +template +__device__ void fill_with(T *buf, T value, const size_t numel) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + buf[i] = value; + } +} +extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i16(int16_t *buf, int16_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } + +template +__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= d1 * d2) { + return; + } + uint32_t idx1 = idx / d2; + uint32_t idx2 = idx - d2 * idx1; + dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2]; +} + +#define COPY2D_OP(TYPENAME, FNNAME) \ +extern "C" __global__ \ +void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \ + copy2d(src, dst, d1, d2, src_s, dst_s); \ +} \ + +COPY2D_OP(float, copy2d_f32) +COPY2D_OP(double, copy2d_f64) +COPY2D_OP(uint8_t, 
copy2d_u8) +COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int16_t, copy2d_i16) +COPY2D_OP(int32_t, copy2d_i32) +COPY2D_OP(int64_t, copy2d_i64) + +#if __CUDA_ARCH__ >= 530 +extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__half, copy2d_f16) +#endif + +#if __CUDA_ARCH__ >= 800 +#include +#include + +extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_bfloat16, copy2d_bf16) + +extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3) +#endif diff --git a/mcandle/mcandle-kernels/src/fused_rms_norm.cu b/mcandle/mcandle-kernels/src/fused_rms_norm.cu new file mode 100644 index 000000000..f012e002a --- /dev/null +++ b/mcandle/mcandle-kernels/src/fused_rms_norm.cu @@ -0,0 +1,82 @@ +#include "cuda_fp16.h" +#include + +#define WARP_SIZE 32 + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) + val += VLLM_SHFL_XOR_SYNC(val, mask); + return val; +} + +__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) { + return warp_size - 1; +} + +__inline__ __device__ constexpr int _calculateWidShift(int warp_size) { + return 5 + (warp_size >> 6); +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[WARP_SIZE]; + constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE); + constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE); + int lane = threadIdx.x & LANE_MASK; + int wid = threadIdx.x >> WID_SHIFT; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? 
shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +#define RMS_NORM_OP(FN_NAME, TYPENAME)\ +extern "C" __global__ void FN_NAME(\ + TYPENAME* __restrict__ out,\ + const TYPENAME* __restrict__ input,\ + const TYPENAME* __restrict__ weight,\ + const float epsilon,\ + const int num_tokens,\ + const int hidden_size) {\ + __shared__ float s_variance;\ + float variance = 0.0f;\ + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {\ + const float x = (float) input[blockIdx.x * hidden_size + idx];\ + variance += x * x;\ + }\ + variance = blockReduceSum(variance);\ + if (threadIdx.x == 0) {\ + s_variance = rsqrtf(variance / hidden_size + epsilon);\ + }\ + __syncthreads();\ + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {\ + float x = (float) input[blockIdx.x * hidden_size + idx];\ + out[blockIdx.x * hidden_size + idx] = ((TYPENAME) (x * s_variance)) * weight[idx];\ + }\ +}\ + +RMS_NORM_OP(rms_norm_f32, float) +RMS_NORM_OP(rms_norm_f16, __half) + +#if __CUDA_ARCH__ >= 800 +#include +RMS_NORM_OP(rms_norm_bf16, __nv_bfloat16) +#endif \ No newline at end of file diff --git a/mcandle/mcandle-kernels/src/fused_rope.cu b/mcandle/mcandle-kernels/src/fused_rope.cu new file mode 100644 index 000000000..9f7873cca --- /dev/null +++ b/mcandle/mcandle-kernels/src/fused_rope.cu @@ -0,0 +1,231 @@ +#include "cuda_fp16.h" + +#ifndef USE_ROCM + #define LDG(arg) __ldg(arg) +#else + #define LDG(arg) *arg +#endif + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) +{ + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = LDG(cos_ptr + x_index); + sin = LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. 
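+    // Interleaved pairing: element 2*k is rotated against element 2*k+1 and both
+    // read cos/sin entry k, whereas the NeoX branch above pairs element k with
+    // element k + embed_dim. E.g. rot_offset = 1 touches indices 2 and 3 using
+    // cos_ptr[1] and sin_ptr[1].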
+ x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = LDG(cos_ptr + x_index / 2); + sin = LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride) +{ + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } +} + +extern "C" __global__ void rotary_embedding_kernel_f32( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + float* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + float* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_f16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __half* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __half* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __half* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
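+  // A plausible host-side launch for this kernel (illustrative sketch only; the
+  // `stream` handle and the `num_tokens` count are assumed to be supplied by the
+  // caller and are not part of this kernel's signature):
+  //
+  //   dim3 grid(num_tokens);
+  //   dim3 block(min(num_heads * rot_dim / 2, 512));
+  //   rotary_embedding_kernel_f16<<<grid, block, 0, stream>>>(
+  //       positions, query, key, cos_sin_cache, rot_dim,
+  //       query_stride, key_stride, num_heads, num_kv_heads, head_size);
+  //
+  // blockIdx.x selects the token, and the threads of the block stride over the
+  // num_heads * (rot_dim / 2) query rotations and the num_kv_heads * (rot_dim / 2)
+  // key rotations inside apply_rotary_embedding.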
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __half* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__half, false>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_f64( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + double* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + double* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const double* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const double* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + + + + +extern "C" __global__ void rotary_embedding_kernel_neox_f32( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + float* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + float* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_f16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __half* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __half* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __half* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. 
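+  // Identical to rotary_embedding_kernel_f16 except that IS_NEOX is true, i.e. the
+  // rotate-half pairing is used. Each row of cos_sin_cache holds rot_dim values:
+  // rot_dim / 2 cosines followed by rot_dim / 2 sines for that position.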
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __half* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__half, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_f64( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + double* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + double* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const double* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const double* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +#if __CUDA_ARCH__ >= 800 +#include +extern "C" __global__ void rotary_embedding_kernel_bf16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __nv_bfloat16* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __nv_bfloat16* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __nv_bfloat16* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __nv_bfloat16* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__nv_bfloat16, false>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +extern "C" __global__ void rotary_embedding_kernel_neox_bf16( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + __nv_bfloat16* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + __nv_bfloat16* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const __nv_bfloat16* __restrict__ cos_sin_cache, // [max_position, rot_dim] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const __nv_bfloat16* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding<__nv_bfloat16, true>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} +#endif \ No newline at end of file diff --git a/mcandle/mcandle-kernels/src/indexing.cu b/mcandle/mcandle-kernels/src/indexing.cu new file mode 100644 index 000000000..52846a04b --- /dev/null +++ b/mcandle/mcandle-kernels/src/indexing.cu @@ -0,0 +1,442 @@ +// WARNING: THIS IS ONLY VALID ASSUMING THAT inp IS CONTIGUOUS! 
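+// All kernels in this file flatten the operand around the indexed axis: left_size is
+// the product of the dimensions before that axis, right_size the product of the
+// dimensions after it, and element (l, d, r) of a [left_size, dim, right_size] view
+// sits at offset (l * dim + d) * right_size + r. For example, selecting along dim 1
+// of a contiguous [2, 5, 3] tensor uses left_size = 2, src_dim_size = 5, right_size = 3.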
+// TODO: proper error reporting when ids are larger than v_size. +#include "cuda_utils.cuh" +#include + +template +__device__ void index_select( + const size_t numel, + const size_t num_dims, + const size_t *info, + const I *ids, + const T *inp, + T *out, + const size_t left_size, + const size_t src_dim_size, + const size_t ids_dim_size, + const size_t right_size +) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + bool b = is_contiguous(num_dims, dims, strides); + for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) { + unsigned int left_i = dst_i / (ids_dim_size * right_size); + unsigned int id_i = dst_i / right_size % ids_dim_size; + unsigned int right_i = dst_i % right_size; + unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i; + unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides); + out[dst_i] = inp[strided_i]; + } +} + +#define IS_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t ids_dim_size, \ + const size_t right_size \ +) { index_select(numel, num_dims, info, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \ + +template +__device__ void gather( + const size_t numel, + const I *ids, + const T *inp, + T *out, + const size_t left_size, + const size_t src_dim_size, + const size_t ids_dim_size, + const size_t right_size +) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + size_t post = i % right_size; + size_t idx = ids[i]; + size_t pre = i / (right_size * ids_dim_size); + size_t src_i = (pre * src_dim_size + idx) * right_size + post; + out[i] = inp[src_i]; + } +} + +#define GATHER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t ids_dim_size, \ + const size_t right_size \ +) { gather(numel, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \ + +template +__device__ void index_add( + const I *ids, + const size_t ids_dim_size, + const T *inp, + T *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const size_t idx = ids[j]; + const size_t src_i = (pre * ids_dim_size + j) * right_size + post; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] += inp[src_i]; + } + } +} + +#if __CUDA_ARCH__ >= 800 +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void scatter_add_f8( + const I *ids, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i 
< numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (pre * src_dim_size + j) * right_size + post; + const size_t idx = ids[src_i]; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} + +template +__device__ void index_add_f8( + const I *ids, + const size_t ids_dim_size, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const size_t idx = ids[j]; + const size_t src_i = (pre * ids_dim_size + j) * right_size + post; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} +#endif + +#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const size_t ids_dim_size, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + +#define IA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const size_t ids_dim_size, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { index_add_f8(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + +template +__device__ void scatter_add( + const I *ids, + const T *inp, + T *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (pre * src_dim_size + j) * right_size + post; + const size_t idx = ids[src_i]; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] += inp[src_i]; + } + } +} + +#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + +#define SA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + + +#if 
__CUDA_ARCH__ >= 800 +IS_OP(__nv_bfloat16, int16_t, is_i16_bf16) +IS_OP(__nv_bfloat16, int32_t, is_i32_bf16) +IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) +IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16) +IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16) +GATHER_OP(__nv_bfloat16, int16_t, gather_i16_bf16) +GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16) +GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16) +GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16) +GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16) +IA_OP(__nv_bfloat16, int16_t, ia_i16_bf16) +IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16) +IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16) +IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16) +IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) +SA_OP(__nv_bfloat16, int16_t, sa_i16_bf16) +SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) +SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) +SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) +SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) + +IS_OP(__nv_fp8_e4m3, int16_t, is_i16_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int32_t, is_i32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int64_t, is_i64_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint32_t, is_u32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint8_t, is_u8_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int16_t, gather_i16_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int32_t, gather_i32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int64_t, gather_i64_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint32_t, gather_u32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint8_t, gather_u8_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int16_t, ia_i16_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int32_t, ia_i32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int64_t, ia_i64_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint32_t, ia_u32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint8_t, ia_u8_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int16_t, sa_i16_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3) +#endif + +#if __CUDA_ARCH__ >= 530 +IS_OP(__half, int16_t, is_i16_f16) +IS_OP(__half, int32_t, is_i32_f16) +IS_OP(__half, int64_t, is_i64_f16) +IS_OP(__half, uint32_t, is_u32_f16) +IS_OP(__half, uint8_t, is_u8_f16) +GATHER_OP(__half, int16_t, gather_i16_f16) +GATHER_OP(__half, int32_t, gather_i32_f16) +GATHER_OP(__half, int64_t, gather_i64_f16) +GATHER_OP(__half, uint32_t, gather_u32_f16) +GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int16_t, ia_i16_f16) +IA_OP(__half, int32_t, ia_i32_f16) +IA_OP(__half, int64_t, ia_i64_f16) +IA_OP(__half, uint32_t, ia_u32_f16) +IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int16_t, sa_i16_f16) +SA_OP(__half, int32_t, sa_i32_f16) +SA_OP(__half, int64_t, sa_i64_f16) +SA_OP(__half, uint32_t, sa_u32_f16) +SA_OP(__half, uint8_t, sa_u8_f16) +#endif + +IS_OP(float, int16_t, is_i16_f32) +IS_OP(double, int16_t, is_i16_f64) +IS_OP(uint8_t, int16_t, is_i16_u8) +IS_OP(uint32_t, int16_t, is_i16_u32) +IS_OP(int16_t, int16_t, is_i16_i16) +IS_OP(int32_t, int16_t, is_i16_i32) +IS_OP(int64_t, int16_t, is_i16_i64) + +IS_OP(float, int32_t, is_i32_f32) +IS_OP(double, int32_t, is_i32_f64) +IS_OP(uint8_t, int32_t, is_i32_u8) +IS_OP(uint32_t, int32_t, is_i32_u32) +IS_OP(int16_t, int32_t, is_i32_i16) +IS_OP(int32_t, int32_t, is_i32_i32) +IS_OP(int64_t, int32_t, is_i32_i64) + +IS_OP(float, int64_t, is_i64_f32) +IS_OP(double, int64_t, is_i64_f64) +IS_OP(uint8_t, int64_t, is_i64_u8) +IS_OP(uint32_t, int64_t, is_i64_u32) +IS_OP(int64_t, int64_t, is_i64_i64) +IS_OP(int32_t, int64_t, is_i64_i32) +IS_OP(int16_t, int64_t, is_i64_i16) + +IS_OP(float, 
uint32_t, is_u32_f32) +IS_OP(double, uint32_t, is_u32_f64) +IS_OP(uint8_t, uint32_t, is_u32_u8) +IS_OP(int16_t, uint32_t, is_u32_i16) +IS_OP(int32_t, uint32_t, is_u32_i32) +IS_OP(int64_t, uint32_t, is_u32_i64) +IS_OP(uint32_t, uint32_t, is_u32_u32) + +IS_OP(float, uint8_t, is_u8_f32) +IS_OP(double, uint8_t, is_u8_f64) +IS_OP(uint8_t, uint8_t, is_u8_u8) +IS_OP(uint32_t, uint8_t, is_u8_u32) +IS_OP(int16_t, uint8_t, is_u8_i16) +IS_OP(int32_t, uint8_t, is_u8_i32) +IS_OP(int64_t, uint8_t, is_u8_i64) + +GATHER_OP(float, int16_t, gather_i16_f32) +GATHER_OP(double, int16_t, gather_i16_f64) +GATHER_OP(uint8_t, int16_t, gather_i16_u8) +GATHER_OP(uint32_t, int16_t, gather_i16_u32) +GATHER_OP(int16_t, int16_t, gather_i16_i16) +GATHER_OP(int32_t, int16_t, gather_i16_i32) +GATHER_OP(int64_t, int16_t, gather_i16_i64) + +GATHER_OP(float, int32_t, gather_i32_f32) +GATHER_OP(double, int32_t, gather_i32_f64) +GATHER_OP(uint8_t, int32_t, gather_i32_u8) +GATHER_OP(uint32_t, int32_t, gather_i32_u32) +GATHER_OP(int16_t, int32_t, gather_i32_i16) +GATHER_OP(int32_t, int32_t, gather_i32_i32) +GATHER_OP(int64_t, int32_t, gather_i32_i64) + +GATHER_OP(float, int64_t, gather_i64_f32) +GATHER_OP(double, int64_t, gather_i64_f64) +GATHER_OP(uint8_t, int64_t, gather_i64_u8) +GATHER_OP(uint32_t, int64_t, gather_i64_u32) +GATHER_OP(int64_t, int64_t, gather_i64_i64) +GATHER_OP(int32_t, int64_t, gather_i64_i32) +GATHER_OP(int16_t, int64_t, gather_i64_i16) + +GATHER_OP(float, uint32_t, gather_u32_f32) +GATHER_OP(double, uint32_t, gather_u32_f64) +GATHER_OP(uint8_t, uint32_t, gather_u32_u8) +GATHER_OP(int16_t, uint32_t, gather_u32_i16) +GATHER_OP(int32_t, uint32_t, gather_u32_i32) +GATHER_OP(int64_t, uint32_t, gather_u32_i64) +GATHER_OP(uint32_t, uint32_t, gather_u32_u32) + +GATHER_OP(float, uint8_t, gather_u8_f32) +GATHER_OP(double, uint8_t, gather_u8_f64) +GATHER_OP(uint8_t, uint8_t, gather_u8_u8) +GATHER_OP(uint32_t, uint8_t, gather_u8_u32) +GATHER_OP(int16_t, uint8_t, gather_u8_i16) +GATHER_OP(int32_t, uint8_t, gather_u8_i32) +GATHER_OP(int64_t, uint8_t, gather_u8_i64) + +IA_OP(float, int16_t, ia_i16_f32) +IA_OP(double, int16_t, ia_i16_f64) +IA_OP(uint8_t, int16_t, ia_i16_u8) +IA_OP(int16_t, int16_t, ia_i16_i16) +IA_OP(uint16_t, int16_t, ia_i16_u16) + +IA_OP(float, int32_t, ia_i32_f32) +IA_OP(double, int32_t, ia_i32_f64) +IA_OP(uint8_t, int32_t, ia_i32_u8) +IA_OP(int32_t, int32_t, ia_i32_i32) +IA_OP(uint32_t, int32_t, ia_i32_u32) + +IA_OP(float, int64_t, ia_i64_f32) +IA_OP(double, int64_t, ia_i64_f64) +IA_OP(uint8_t, int64_t, ia_i64_u8) +IA_OP(int64_t, int64_t, ia_i64_i64) +IA_OP(uint32_t, int64_t, ia_i64_u32) +IA_OP(int32_t, int64_t, ia_i64_i32) +IA_OP(int16_t, int64_t, ia_i64_i16) + +IA_OP(float, uint32_t, ia_u32_f32) +IA_OP(double, uint32_t, ia_u32_f64) +IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int16_t, uint32_t, ia_u32_i16) +IA_OP(int32_t, uint32_t, ia_u32_i32) +IA_OP(int64_t, uint32_t, ia_u32_i64) +IA_OP(uint32_t, uint32_t, ia_u32_u32) + +IA_OP(float, uint8_t, ia_u8_f32) +IA_OP(double, uint8_t, ia_u8_f64) +IA_OP(uint8_t, uint8_t, ia_u8_u8) +IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int16_t, uint8_t, ia_u8_i16) +IA_OP(int32_t, uint8_t, ia_u8_i32) +IA_OP(int64_t, uint8_t, ia_u8_i64) + +SA_OP(float, int16_t, sa_i16_f32) +SA_OP(double, int16_t, sa_i16_f64) +SA_OP(uint8_t, int16_t, sa_i16_u8) +SA_OP(int16_t, int16_t, sa_i16_i16) +SA_OP(int32_t, int16_t, sa_i16_i32) +SA_OP(uint32_t, int16_t, sa_i16_u32) + +SA_OP(float, int32_t, sa_i32_f32) +SA_OP(double, int32_t, sa_i32_f64) +SA_OP(uint8_t, int32_t, sa_i32_u8) 
+SA_OP(int16_t, int32_t, sa_i32_i16) +SA_OP(int32_t, int32_t, sa_i32_i32) +SA_OP(uint32_t, int32_t, sa_i32_u32) + +SA_OP(float, int64_t, sa_i64_f32) +SA_OP(double, int64_t, sa_i64_f64) +SA_OP(uint8_t, int64_t, sa_i64_u8) +SA_OP(int16_t, int64_t, sa_i64_i16) +SA_OP(int32_t, int64_t, sa_i64_i32) +SA_OP(int64_t, int64_t, sa_i64_i64) +SA_OP(uint32_t, int64_t, sa_i64_u32) + +SA_OP(float, uint32_t, sa_u32_f32) +SA_OP(double, uint32_t, sa_u32_f64) +SA_OP(uint8_t, uint32_t, sa_u32_u8) +SA_OP(int16_t, uint32_t, sa_u32_i16) +SA_OP(int32_t, uint32_t, sa_u32_i32) +SA_OP(int64_t, uint32_t, sa_u32_i64) +SA_OP(uint32_t, uint32_t, sa_u32_u32) + +SA_OP(float, uint8_t, sa_u8_f32) +SA_OP(double, uint8_t, sa_u8_f64) +SA_OP(uint8_t, uint8_t, sa_u8_u8) +SA_OP(uint32_t, uint8_t, sa_u8_u32) +SA_OP(int16_t, uint8_t, sa_u8_i16) +SA_OP(int32_t, uint8_t, sa_u8_i32) +SA_OP(int64_t, uint8_t, sa_u8_i64) diff --git a/mcandle/mcandle-kernels/src/kvconcat.cu b/mcandle/mcandle-kernels/src/kvconcat.cu new file mode 100644 index 000000000..2bbd6c53a --- /dev/null +++ b/mcandle/mcandle-kernels/src/kvconcat.cu @@ -0,0 +1,54 @@ +#include "cuda_utils.cuh" +#include + +template +__device__ __forceinline__ void kvconcat_dim0_kernel(T *ltensor, T* rtensor, T *out, + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) { + size_t idx = GetThreadIdx(); + if (idx < chunk_l * lstride) { + out[idx] = ltensor[idx]; + } else { + out[idx] = rtensor[idx - chunk_l * lstride]; + } +} +template +__device__ __forceinline__ void kvconcat_dim2_kernel(T *ltensor, T* rtensor, T *out, + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) { + int thread_id = GetThreadIdx(); + int out_stride = lstride + rstride; + int idx = thread_id / out_stride; + int j = thread_id % out_stride; + T* pLeft = ltensor + idx * lstride; + T* pRight = rtensor + idx * rstride; + T* pOut = out + idx * out_stride; + if (idx < chunk_l) { + if (j < lstride) + pOut[j] = pLeft[j]; + else + pOut[j] = pRight[j - lstride]; + } +} + +#define KVCONCAT_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME(TYPENAME *ltensor, TYPENAME* rtensor, TYPENAME *out, const size_t concat_dim,\ + const size_t chunk_l, const size_t chunk_r, const size_t lstride, const size_t rstride) {\ + if (concat_dim == 2)\ + kvconcat_dim2_kernel(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);\ + else if (concat_dim == 0) {\ + if (blockIdx.x == 0 && threadIdx.x ==0) \ + kvconcat_dim0_kernel(ltensor, rtensor, out, chunk_l, chunk_r, lstride, rstride);\ + }\ +}\ + +KVCONCAT_OP(uint8_t, kvconcat_u8) +KVCONCAT_OP(double, kvconcat_f64) +KVCONCAT_OP(float, kvconcat_f32) + +#if __CUDA_ARCH__ >= 530 +KVCONCAT_OP(__half, kvconcat_f16) +#endif + +#if __CUDA_ARCH__ >= 800 +KVCONCAT_OP(__nv_bfloat16, kvconcat_bf16) +KVCONCAT_OP(__nv_fp8_e4m3, kvconcat_f8_e4m3) +#endif \ No newline at end of file diff --git a/mcandle/mcandle-kernels/src/lib.rs b/mcandle/mcandle-kernels/src/lib.rs new file mode 100644 index 000000000..0bb490ca1 --- /dev/null +++ b/mcandle/mcandle-kernels/src/lib.rs @@ -0,0 +1,14 @@ +pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx")); +pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); +pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); +pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); +pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); +pub const FUSED_RMS_NORM: &str = 
include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); +pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx")); +pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); +pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx")); +pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx")); +pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); +pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx")); +pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); +pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); diff --git a/mcandle/mcandle-kernels/src/quantized.cu b/mcandle/mcandle-kernels/src/quantized.cu new file mode 100644 index 000000000..05f878f3d --- /dev/null +++ b/mcandle/mcandle-kernels/src/quantized.cu @@ -0,0 +1,4320 @@ +// Kernels adapted from llama.cpp ggml-cuda.cu +// https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda.cu +#include "cuda_fp16.h" +#include "cuda_bf16.h" +#include + +#define GGML_UNUSED(x) (void)(x) +#define GGML_CUDA_ASSUME(x) + +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +#undef GGML_CUDA_F16 +#define GGML_CUDA_DMMV_X 32 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef uint16_t ggml_fp16_t; +typedef float dfloat; // dequantize float +typedef float2 dfloat2; +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); + +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +static __device__ __forceinline__ float warp_reduce_max(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +} + +static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + + +#define WARP_SIZE 32 +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. 
for which __hmax and __hmax2 are known to work (may be higher than needed) + +#define CC_PASCAL 600 +#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define CC_VOLTA 700 +#define CC_OFFSET_AMD 1000000 +#define CC_RDNA1 (CC_OFFSET_AMD + 1010) +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) +#define CC_RDNA3 (CC_OFFSET_AMD + 1100) + +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 +#define NWARPS_Q4_0_AMPERE 4 +#else +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 +#define NWARPS_Q4_0_AMPERE 4 +#endif +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 +#define NWARPS_Q4_0_PASCAL 8 + +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 +#define NWARPS_Q4_1_AMPERE 4 +#else +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#endif +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 +#define NWARPS_Q4_1_PASCAL 8 + +#define MMQ_X_Q5_0_RDNA2 64 +#define MMQ_Y_Q5_0_RDNA2 128 +#define NWARPS_Q5_0_RDNA2 8 +#define MMQ_X_Q5_0_RDNA1 64 +#define MMQ_Y_Q5_0_RDNA1 64 +#define NWARPS_Q5_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_0_AMPERE 4 +#define MMQ_Y_Q5_0_AMPERE 32 +#define NWARPS_Q5_0_AMPERE 4 +#else +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 +#define NWARPS_Q5_0_AMPERE 4 +#endif +#define MMQ_X_Q5_0_PASCAL 64 +#define MMQ_Y_Q5_0_PASCAL 64 +#define NWARPS_Q5_0_PASCAL 8 + +#define MMQ_X_Q5_1_RDNA2 64 +#define MMQ_Y_Q5_1_RDNA2 128 +#define NWARPS_Q5_1_RDNA2 8 +#define MMQ_X_Q5_1_RDNA1 64 +#define MMQ_Y_Q5_1_RDNA1 64 +#define NWARPS_Q5_1_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_1_AMPERE 4 +#define MMQ_Y_Q5_1_AMPERE 32 +#define NWARPS_Q5_1_AMPERE 4 +#else +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 +#define NWARPS_Q5_1_AMPERE 4 +#endif +#define MMQ_X_Q5_1_PASCAL 64 +#define MMQ_Y_Q5_1_PASCAL 64 +#define NWARPS_Q5_1_PASCAL 8 + +#define MMQ_X_Q8_0_RDNA2 64 +#define MMQ_Y_Q8_0_RDNA2 128 +#define NWARPS_Q8_0_RDNA2 8 +#define MMQ_X_Q8_0_RDNA1 64 +#define MMQ_Y_Q8_0_RDNA1 64 +#define NWARPS_Q8_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q8_0_AMPERE 4 +#define MMQ_Y_Q8_0_AMPERE 32 +#define NWARPS_Q8_0_AMPERE 4 +#else +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 +#define NWARPS_Q8_0_AMPERE 4 +#endif +#define MMQ_X_Q8_0_PASCAL 64 +#define MMQ_Y_Q8_0_PASCAL 64 +#define NWARPS_Q8_0_PASCAL 8 + +#define MMQ_X_Q2_K_RDNA2 64 +#define MMQ_Y_Q2_K_RDNA2 128 +#define NWARPS_Q2_K_RDNA2 8 +#define MMQ_X_Q2_K_RDNA1 128 +#define MMQ_Y_Q2_K_RDNA1 32 +#define NWARPS_Q2_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q2_K_AMPERE 4 +#define MMQ_Y_Q2_K_AMPERE 32 +#define NWARPS_Q2_K_AMPERE 4 +#else +#define MMQ_X_Q2_K_AMPERE 64 +#define MMQ_Y_Q2_K_AMPERE 128 +#define NWARPS_Q2_K_AMPERE 4 +#endif +#define MMQ_X_Q2_K_PASCAL 64 +#define MMQ_Y_Q2_K_PASCAL 64 +#define NWARPS_Q2_K_PASCAL 8 + +#define MMQ_X_Q3_K_RDNA2 128 +#define MMQ_Y_Q3_K_RDNA2 64 +#define NWARPS_Q3_K_RDNA2 8 +#define MMQ_X_Q3_K_RDNA1 32 +#define MMQ_Y_Q3_K_RDNA1 128 +#define NWARPS_Q3_K_RDNA1 8 +#if 
defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q3_K_AMPERE 4 +#define MMQ_Y_Q3_K_AMPERE 32 +#define NWARPS_Q3_K_AMPERE 4 +#else +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 +#define NWARPS_Q3_K_AMPERE 4 +#endif +#define MMQ_X_Q3_K_PASCAL 64 +#define MMQ_Y_Q3_K_PASCAL 64 +#define NWARPS_Q3_K_PASCAL 8 + +#define MMQ_X_Q4_K_RDNA2 64 +#define MMQ_Y_Q4_K_RDNA2 128 +#define NWARPS_Q4_K_RDNA2 8 +#define MMQ_X_Q4_K_RDNA1 32 +#define MMQ_Y_Q4_K_RDNA1 64 +#define NWARPS_Q4_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_K_AMPERE 4 +#define MMQ_Y_Q4_K_AMPERE 32 +#define NWARPS_Q4_K_AMPERE 4 +#else +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 +#define NWARPS_Q4_K_AMPERE 4 +#endif +#define MMQ_X_Q4_K_PASCAL 64 +#define MMQ_Y_Q4_K_PASCAL 64 +#define NWARPS_Q4_K_PASCAL 8 + +#define MMQ_X_Q5_K_RDNA2 64 +#define MMQ_Y_Q5_K_RDNA2 128 +#define NWARPS_Q5_K_RDNA2 8 +#define MMQ_X_Q5_K_RDNA1 32 +#define MMQ_Y_Q5_K_RDNA1 64 +#define NWARPS_Q5_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q5_K_AMPERE 4 +#define MMQ_Y_Q5_K_AMPERE 32 +#define NWARPS_Q5_K_AMPERE 4 +#else +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 +#define NWARPS_Q5_K_AMPERE 4 +#endif +#define MMQ_X_Q5_K_PASCAL 64 +#define MMQ_Y_Q5_K_PASCAL 64 +#define NWARPS_Q5_K_PASCAL 8 + +#define MMQ_X_Q6_K_RDNA2 64 +#define MMQ_Y_Q6_K_RDNA2 128 +#define NWARPS_Q6_K_RDNA2 8 +#define MMQ_X_Q6_K_RDNA1 32 +#define MMQ_Y_Q6_K_RDNA1 64 +#define NWARPS_Q6_K_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q6_K_AMPERE 4 +#define MMQ_Y_Q6_K_AMPERE 32 +#define NWARPS_Q6_K_AMPERE 4 +#else +#define MMQ_X_Q6_K_AMPERE 64 +#define MMQ_Y_Q6_K_AMPERE 64 +#define NWARPS_Q6_K_AMPERE 4 +#endif +#define MMQ_X_Q6_K_PASCAL 64 +#define MMQ_Y_Q6_K_PASCAL 64 +#define NWARPS_Q6_K_PASCAL 8 + + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 +#define QR4_0 2 +#define QI4_0 (QK4_0 / (4 * QR4_0)) +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 (QK4_1 / (4 * QR4_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 (QK5_0 / (4 * QR5_0)) +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 (QK5_1 / (4 * QR5_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +#define QI8_0 (QK8_0 / (4 * QR8_0)) +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 (QK8_1 / 
(4 * QR8_1)) +typedef struct { + half2 ds; // ds.x = delta, ds.y = sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); +typedef void (*load_tiles_cuda_t)( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); +typedef float (*vec_dot_q_mul_mat_cuda_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); + +#define QR2_K 4 +#define QI2_K (QK_K / (4*QR2_K)) +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half2 dm; // super-block scale for quantized scales/mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +#define QR3_K 4 +#define QI3_K (QK_K / (4*QR3_K)) +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + uint8_t scales[2]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; +//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); + +#define QR4_K 2 +#define QI4_K (QK_K / (4*QR4_K)) +#ifdef GGML_QKK_64 +typedef struct { + half dm[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif + +#define QR5_K 2 +#define QI5_K (QK_K / (4*QR5_K)) +#ifdef GGML_QKK_64 +typedef struct { + half d; // super-block scale + int8_t scales[QK_K/16]; // block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +#define QR6_K 2 +#define QI6_K (QK_K / (4*QR6_K)) +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) 
+ 13*QK_K/16, "wrong q6_K block size/padding"); + +// In llama.cpp this is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + + +template +static __device__ __forceinline__ void mul_mat_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ncols_dst = ncols_y; + + const int row_dst_0 = blockIdx.x*mmq_y; + const int & row_x_0 = row_dst_0; + + const int col_dst_0 = blockIdx.y*mmq_x; + const int & col_y_0 = col_dst_0; + + int * tile_x_ql = nullptr; + half2 * tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + + allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); + + __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + + float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + + load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); + +#pragma unroll + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; + + const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int col_y_eff = min(col_y_0 + ids, ncols_y-1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = __low2half(*dsi_src); + } + } + + __syncthreads(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + sum[i/WARP_SIZE][j/nwarps] += vec_dot( + tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, + threadIdx.x + i, threadIdx.y + j, k); + } + } + } + + __syncthreads(); + } + } + +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = col_dst_0 + j + threadIdx.y; + + if (col_dst >= ncols_dst) { + return; + } + +#pragma unroll + for (int i = 
0; i < mmq_y; i += WARP_SIZE) { + const int row_dst = row_dst_0 + threadIdx.x + i; + + if (row_dst >= nrows_dst) { + continue; + } + + dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + (void)x_qh; (void)x_sc; + + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0 * bx0 = (const block_q4_0 *) vx; + + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx0 = (const block_q4_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + + +template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + (void)x_qh; (void)x_sc; + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_dm; +} + +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_0 * x = (const block_q4_0 *) vx; + 
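+    // block_q4_0 stores 32 weights as one fp16 scale d plus 16 bytes of quants;
+    // byte j carries weight j in its low nibble and weight j + 16 in its high
+    // nibble, and a nibble q dequantizes to (q - 8) * d (e.g. q = 15 with d = 0.5
+    // gives 3.5).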
+ const dfloat d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {8.0f, 8.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 8.0f) * d; + v.y = (v.y - 8.0f) * d; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const dfloat d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {16.0f, 16.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 16.0f) * d; + v.y = (v.y - 16.0f) * d; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const dfloat d = x[ib].d; + + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); +#else + v.x *= d; + v.y *= d; +#endif // GGML_CUDA_F16 +} + + +template +static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { + const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x); + + if (i >= k) { + return; + } + + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 
1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; +} + +template +static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + + const int64_t i = blockIdx.x; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_0 * x = (const block_q4_0 *)vx + ib; + const float d = __half2float(x->d); + const float dm = -8*d; + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l+ 0] = d * (q[l] & 0xF) + dm; + y[l+16] = d * (q[l] >> 4) + dm; + } +} + +template +static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + + const int64_t i = blockIdx.x; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_1 * x = (const block_q4_1 *)vx + ib; + const float2 d = __half22float2(x->dm); + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l+ 0] = d.x * (q[l] & 0xF) + d.y; + y[l+16] = d.x * (q[l] >> 4) + d.y; + } +} + +//================================== k-quants + +template +static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + dst_t * y = yy + i*QK_K + 16*is + il; + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif + +} + +template +static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q3_K * x = (const block_q3_K *) vx; + +#if QK_K == 256 + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? 
(x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + dst_t * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } +#endif + +} + +#if QK_K == 256 +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +template +static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + dst_t * y = yy + i*QK_K; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif +} + +template +static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 
= dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const int is = tid/16; // 0 or 1 + const uint8_t h = x[i].qh[in] >> im; + const float d = x[i].d; + dst_t * y = yy + i*QK_K + tid; + y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); + y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); +#endif +} + +template +static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int64_t i = blockIdx.x; +#if QK_K == 256 + + // assume 64 threads - this is very slightly better than the one below + const int64_t tid = threadIdx.x; + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +#else + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 + + dst_t * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif +} + +template +static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + const int i = blockIdx.x; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 8*il; + + const block_q8_0 * x = (const block_q8_0 *)vx + ib; + const float d = __half2float(x->d); + + const int8_t * q = x->qs + 8*il; + + for (int l = 0; l < 8; ++l) { + y[l] = d * q[l]; + } +} + +template +static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q8_K * x = (const block_q8_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int n = 8; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const int8_t * q = x[i].qs + 64*il + n*ir; + + for (int l = 0; l < n; ++l) { + y[l] = q[l] * x[i].d; + } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + float * y = yy + i*QK_K; + y[tid] = x[i].d * x[i].scales[0]; +#endif +} + +template +static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + 
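+    // Thin wrapper: q5_0 has no dedicated per-block dequant kernel here; it
+    // forwards to the generic dequantize_block path above, presumably
+    // instantiated with the q5_0 block geometry and the dequantize_q5_0
+    // element kernel defined earlier.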
return dequantize_block(vx, yy, nb32); +} + +template +static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + return dequantize_block(vx, yy, nb32); +} + +#define DEQUANTIZE_K(QNAME) \ +extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \ + dequantize_block_##QNAME(vx, y); \ +} \ +extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \ + dequantize_block_##QNAME(vx, y); \ +} \ + +#define DEQUANTIZE(QNAME) \ +extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \ + dequantize_block_##QNAME(vx, y, k); \ +} \ +extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \ + dequantize_block_##QNAME(vx, y, k); \ +} \ + +DEQUANTIZE_K(q2_K) +DEQUANTIZE_K(q3_K) +DEQUANTIZE_K(q4_K) +DEQUANTIZE_K(q5_K) +DEQUANTIZE_K(q6_K) +DEQUANTIZE_K(q8_K) +DEQUANTIZE(q4_0) +DEQUANTIZE(q4_1) +DEQUANTIZE(q5_0) +DEQUANTIZE(q5_1) +DEQUANTIZE(q8_0) + +template +static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = blockIdx.x*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int tid = threadIdx.x; + + const int iter_stride = 2*GGML_CUDA_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + const int y_offset = qr == 1 ? 1 : qk/2; + +// partial sum for each thread +#ifdef GGML_CUDA_F16 + half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_CUDA_F16 + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel(vx, ib, iqs + j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_CUDA_F16 + tmp += __hmul2(v, { + y[iybs + iqs + j/qr + 0], + y[iybs + iqs + j/qr + y_offset] + }); +#else + tmp += v.x * y[iybs + iqs + j/qr + 0]; + tmp += v.y * y[iybs + iqs + j/qr + y_offset]; +#endif // GGML_CUDA_F16 + } + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { +#ifdef GGML_CUDA_F16 + dst[row] = tmp.x + tmp.y; +#else + dst[row] = tmp; +#endif // GGML_CUDA_F16 + } +} + +extern "C" __global__ void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { + dequantize_mul_mat_vec(vx, y, dst, ncols, nrows); +} + +extern "C" __global__ void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { + dequantize_mul_mat_vec(vx, y, 
dst, ncols, nrows); +} + +extern "C" __global__ void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { + dequantize_mul_mat_vec(vx, y, dst, ncols, nrows); +} + +extern "C" __global__ void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { + dequantize_mul_mat_vec(vx, y, dst, ncols, nrows); +} +extern "C" __global__ void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { + dequantize_mul_mat_vec(vx, y, dst, ncols, nrows); +} + +extern "C" __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q2_K * x = (const block_q2_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 16/K_QUANTS_PER_ITERATION; + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; + + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + + } + tmp += dall * sum1 - dmin * sum2; + + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const float2 dall = __half22float2(x[i].dm); + + float sum1 = 0, sum2 = 0; + for (int l = 
0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x * sum1 - dall.y * sum2; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +extern "C" __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q3_K * x = (const block_q3_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + const uint8_t * h = x[i].hmask + l0; + + const uint16_t * a = (const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = x[i].d; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 
0 : 4)); + } + tmp += d * sum; + + } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +extern "C" __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; + s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +extern "C" __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { + + const 
int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 2; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + uint16_t q16[8]; + const uint8_t * q4 = (const uint8_t *)q16; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * ql1 = x[i].qs + q_offset; + const uint8_t * qh = x[i].qh + l0; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 sum = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + const uint16_t * q1 = (const uint16_t *)ql1; + const uint16_t * q2 = q1 + 32; + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[8] & 0x0f0f; + q16[2] = (q1[0] >> 4) & 0x0f0f; + q16[3] = (q1[8] >> 4) & 0x0f0f; + q16[4] = q2[0] & 0x0f0f; + q16[5] = q2[8] & 0x0f0f; + q16[6] = (q2[0] >> 4) & 0x0f0f; + q16[7] = (q2[8] >> 4) & 0x0f0f; + for (int l = 0; l < n; ++l) { + sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) + + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); + smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; + } + tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + } + +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const int8_t * s = x[i].scales; + const float * y = yy + i*QK_K + step; + const float d = x[i].d; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) + + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 
0 : 16)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +extern "C" __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.x*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + +#if QK_K == 256 + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * 
((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; +#else + const float2 dm4f = __half22float2(dm4); + const float2 ds8f = __half22float2(ds8); + const float d4d8 = dm4f.x * ds8f.x; + const float m4s8 = dm4f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 
+#define VDR_Q5_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; +#else + const float2 dm5f = __half22float2(dm5); + const float2 ds8f = __half22float2(ds8); + const float d5d8 = dm5f.x * ds8f.x; + const float m5s8 = dm5f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const float & d8_0, const float & d8_1) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * sumi; +} + +template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; +#else + const float2 dm8f = __half22float2(dm8); + const float2 ds8f = __half22float2(ds8); + const float d8d8 = dm8f.x * ds8f.x; + const float m8s8 = dm8f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const 
int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float & d8) { + + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const float2 dm2f = __half22float2(dm2); + + return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { + + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; 
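+    // Each iteration below covers one group of 32 q4_K values sharing the
+    // 6-bit scale sc[i] and min m[i]: sumi_d accumulates __dp4a dot products
+    // of the 4-bit quants against the q8_1 quants and is weighted by
+    // ds8[i].x (q8_1 scale) * sc[i]; sumf_m accumulates ds8[i].y (the
+    // precomputed q8_1 block sum) * m[i]. The q4_K super-block d/dmin pair
+    // (dm4) combines the two accumulators at the end.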
+ +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + const float2 dm5f = __half22float2(dm5); + + return dm5f.x*sumf_d - dm5f.y*sumf_m; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = 
__dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); +} + + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * 
__restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. 
bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + +#else + + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; + + const float d8_1 = __low2float(bq8_1[0].ds); + const float d8_2 = __low2float(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * q4 = (const int *)bq4_K->qs + (iqs/2); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; +#endif +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; 
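+    // At this point aux[0] holds the two 6-bit sub-block scales and aux[1]
+    // the two 6-bit mins selected by bq8_offset, reconstructed from the
+    // packed 12-byte q5_K scales array (the else-branch stitches the high
+    // 2 bits back from the 0xc0c0 portions); sc/m expose them byte-wise to
+    // vec_dot_q5_K_q8_1_impl_vmmq below.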
+ +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); + +#else + + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = __low2half(bq8_1[0].ds); + const float d8_2 = __low2half(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * ql = (const int *)bq5_K->qs + (iqs/2); + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * (iqs/2); // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; +#endif +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + +// https://github.com/ggerganov/llama.cpp/blob/c50a82ce0f71558cbb8e555146ba124251504b38/ggml-cuda/mmvq.cu#L4 +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); + +template +static __device__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int nwarps = 1; + constexpr int rows_per_cuda_block = 1; +#else + constexpr int nwarps = ncols_y <= 4 ? 4 : 2; + constexpr int rows_per_cuda_block = ncols_y == 1 ? 
1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; + +// partial sum for each thread + float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx + + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda( + &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs); + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE]; + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + // sum up partial sums and write back result +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + } + tmp[j][i] = warp_reduce_sum(tmp[j][i]); + } + + if (threadIdx.x < rows_per_cuda_block) { + dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + } + } +} + +// batch size = 1 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, 
nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<1, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 2 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int 
nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 3 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void 
mul_mat_vec_q3_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 4 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI3_K, 
block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 5 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda5( + const void * vx, const void * 
vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda5( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<5, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 6 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, 
ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda6( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<6, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 7 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int 
nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda7( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<7, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 8 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ 
void mul_mat_vec_q6_K_q8_1_cuda8( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<8, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; + + if (ix >= kx_padded) { + return; + } + + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + + const int i_padded = iy*kx_padded + ix; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index + + const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} + +// Kernels from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/mmq.cu + +template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + + *x_ql = tile_x_ql; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0 * bx0 = (const block_q5_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI5_0) + 
i / QI5_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + + +template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1 * bx0 = (const block_q5_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); 
GGML_UNUSED(x_sc); + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + return vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + float * x_dmf = (float *) x_dm; + + const block_q8_0 * bx0 = (const block_q8_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + +template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); + + 
GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K * bx0 = (const block_q2_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + } +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); + + const int kbx = k / QI2_K; + const int ky = (k % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); +} + +template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K * bx0 = (const block_q3_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = 
bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = k % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + } +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kbx = k / QI3_K; + const int ky = (k % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = __vsubss4(vll, vlh); + } + + const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ 
x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_K; // == 0 if QK_K == 256 + const int kqsx = k % QI4_K; // == k if QK_K == 256 + + const block_q4_K * bx0 = (const block_q4_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; +#else + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + GGML_UNUSED(x_qh); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / 
QI5_K; // == 0 if QK_K == 256 + const int kqsx = k % QI5_K; // == k if QK_K == 256 + + const block_q5_K * bx0 = (const block_q5_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); +} + +template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + GGML_UNUSED(x_qh); + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { 
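+    // Stages one tile of q6_K data in shared memory: the 4-bit low quants (ql)
+    // and 2-bit high quants (qh) are combined and re-centered by subtracting 32
+    // (__vsubss4 with 0x20202020) into x_ql, the per-block d factors are stored
+    // in x_dm (viewed as a float array) and the int8 scales in x_sc. The
+    // (2*WARP_SIZE + 1) row stride leaves one int of padding per row, the usual
+    // trick to avoid shared-memory bank conflicts.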
+ GGML_UNUSED(x_qh); + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const block_q6_K * bx0 = (const block_q6_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + } +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); +} + + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const float * x_dmf = (const float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + return vec_dot_q4_0_q8_1_impl + (&x_ql[i * 
(WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + + +extern "C" __global__ void + mul_mat_q4_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_0_AMPERE; + const int mmq_y = MMQ_Y_Q4_0_AMPERE; + const int nwarps = NWARPS_Q4_0_AMPERE; + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void + mul_mat_q4_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_1_AMPERE; + const int mmq_y = MMQ_Y_Q4_1_AMPERE; + const int nwarps = NWARPS_Q4_1_AMPERE; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + + +extern "C" __global__ void + mul_mat_q5_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q5_0_AMPERE; + const int mmq_y = MMQ_Y_Q5_0_AMPERE; + const int nwarps = NWARPS_Q5_0_AMPERE; + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void +mul_mat_q5_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q5_1_AMPERE; + const int mmq_y = MMQ_Y_Q5_1_AMPERE; + const int nwarps = NWARPS_Q5_1_AMPERE; + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void + mul_mat_q8_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q8_0_AMPERE; + const int mmq_y = MMQ_Y_Q8_0_AMPERE; + const int nwarps = NWARPS_Q8_0_AMPERE; + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void +mul_mat_q2_K( + const void * 
__restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q2_K_AMPERE; + const int mmq_y = MMQ_Y_Q2_K_AMPERE; + const int nwarps = NWARPS_Q2_K_AMPERE; + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void + mul_mat_q3_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q3_K_AMPERE; + const int mmq_y = MMQ_Y_Q3_K_AMPERE; + const int nwarps = NWARPS_Q3_K_AMPERE; + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void + mul_mat_q4_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_K_AMPERE; + const int mmq_y = MMQ_Y_Q4_K_AMPERE; + const int nwarps = NWARPS_Q4_K_AMPERE; + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void +mul_mat_q5_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q5_K_AMPERE; + const int mmq_y = MMQ_Y_Q5_K_AMPERE; + const int nwarps = NWARPS_Q5_K_AMPERE; + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +extern "C" __global__ void + mul_mat_q6_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q6_K_AMPERE; + const int mmq_y = MMQ_Y_Q6_K_AMPERE; + const int nwarps = NWARPS_Q6_K_AMPERE; + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} diff --git a/mcandle/mcandle-kernels/src/reduce.cu b/mcandle/mcandle-kernels/src/reduce.cu new file mode 100644 index 000000000..7715fbf61 --- /dev/null +++ b/mcandle/mcandle-kernels/src/reduce.cu @@ -0,0 +1,618 @@ +#include "cuda_utils.cuh" +#include +#include + +#define WARP_SIZE 32 +const int BLOCK_SIZE = 1024; + +// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32 +// but also expect a f32 output so that this can be used for normalization e.g. +// in softmax. + +// Fast reduce sum kernel, this assumes that the dimensions to loop over are at +// the end, each block is responsible for populating one value in the output +// array. There are at most 1024 threads per block. 
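+//
+// Rough usage sketch (assumptions: a contiguous [8, 4096] f32 input and an
+// exported f32 instantiation named fast_sum_f32, neither of which is defined
+// at this point in the file). Summing over the last dimension launches one
+// block per output element:
+//
+//   const size_t info[] = {8, 4096, 4096, 1}; // dims followed by strides, in device memory
+//   fast_sum_f32<<<8, BLOCK_SIZE>>>(/*src_numel=*/8*4096, /*el_to_sum_per_block=*/4096,
+//                                   /*num_dims=*/2, info, src, dst);
+//
+// Each block then folds its BLOCK_SIZE per-thread partial sums into
+// dst[blockIdx.x] with the shared-memory tree reduction below.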
+template +__device__ void +fast_sum(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, T *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + shr[tid] = 0; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + shr[tid] += src[strided_i]; + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s) + shr[tid] += shr[tid + s]; + } + + if (tid == 0) + dst[dst_id] = shr[0]; +} + +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); + a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + } + return a; +} + +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +// LayerNorm implementation adapted from ggml, accumulation is made using f32. 
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477 +template +__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float2 mean_var = make_float2(0.f, 0.f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + mean_var.x += xi; + mean_var.y += xi * xi; + } + + // sum up partial sums + mean_var = warp_reduce_sum(mean_var); + if (block_size > WARP_SIZE) { + __shared__ float2 s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = mean_var; + } + __syncthreads(); + mean_var = s_sum[lane_id]; + mean_var = warp_reduce_sum(mean_var); + } + + const float mean = mean_var.x / ncols; + const float var = mean_var.y / ncols - mean * mean; + const float inv_std = rsqrtf(var + eps); + + if (alpha == nullptr && beta == nullptr) { + for (int col = tid; col < ncols; col += block_size) { + float lhs = (static_cast(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast(lhs); + } + } + else if (alpha == nullptr && beta != nullptr) { + for (int col = tid; col < ncols; col += block_size) { + float b = static_cast(beta[col]); + float lhs = (static_cast(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast(lhs + b); + } + } + else if (alpha != nullptr && beta == nullptr) { + for (int col = tid; col < ncols; col += block_size) { + float a = static_cast(alpha[col]); + float lhs = (static_cast(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast(lhs * a); + } + } + else { + for (int col = tid; col < ncols; col += block_size) { + float a = static_cast(alpha[col]); + float b = static_cast(beta[col]); + float lhs = (static_cast(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast(lhs * a + b); + } + } +} + +// RmsNorm implementation adapted from ggml, accumulation is made using f32. +// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523 +template +__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = static_cast(x[row*ncols + col]); + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + if (alpha == nullptr) { + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = static_cast(scale * static_cast(x[row*ncols + col])); + } + } + else { + for (int col = tid; col < ncols; col += block_size) { + float a = static_cast(alpha[col]); + dst[row*ncols + col] = static_cast(scale * static_cast(x[row*ncols + col]) * a); + } + } +} + +// Softmax implementation adapted from ggml. 
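+// Each thread covers a strided subset of the row: the row maximum is found and
+// subtracted before exponentiating for numerical stability, i.e.
+// softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)); both the max and
+// the normalizing sum are combined across the warp with __shfl_xor_sync.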
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159 +template +__device__ void softmax(const T * x, T * dst, const int ncols) { + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int block_size = blockDim.y; + const int tid = threadIdx.y; + + T max_val = -INFINITY; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + max_val = maxg(max_val, x[i]); + } + + // find the max value in the block +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); + } + + ACC tmp = 0.; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + const T val = expg(x[i] - max_val); + tmp += static_cast(val); + dst[i] = val; + } + + // sum up partial sums +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + const ACC inv_tmp = 1. / tmp; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + dst[i] *= inv_tmp; + } +} + +template +__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx >= bh * td) return; + + uint32_t rope_idx = idx % (td / 2); + T c = cos[rope_idx]; + T s = sin[rope_idx]; + + dst[2 * idx] = src[2 * idx] * c - src[2 * idx + 1] * s; + dst[2 * idx + 1] = src[2 * idx] * s + src[2 * idx + 1] * c; +} + +template +__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx >= bh * td) return; + + uint32_t i_bh = idx / (td / 2); + uint32_t i_td = idx - (td / 2) * i_bh; + uint32_t i_t = i_td / (d / 2); + uint32_t i_d = i_td - (d / 2) * i_t; + uint32_t i1 = i_bh * td + i_t * d + i_d; + uint32_t i2 = i1 + d / 2; + uint32_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +template +__device__ void rope_thd( + const T * src, + const T * cos, + const T * sin, + T * dst, + const uint32_t b, + const uint32_t t, + const uint32_t h, + const uint32_t d +) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (2 * idx >= b * t * h * d) return; + + uint32_t i_bth = idx / (d / 2); + uint32_t i_d = idx - (d / 2) * i_bth; + uint32_t i_t = (i_bth / h) % t; + uint32_t i1 = i_bth * d + i_d; + uint32_t i2 = i1 + d / 2; + uint32_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +template +__device__ void +fast_max(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, T *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + shr[tid] = -INFINITY; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. 
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + shr[tid] = maxg(shr[tid], src[strided_i]); + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s) + shr[tid] = maxg(shr[tid], shr[tid + s]); + } + + if (tid == 0) + dst[dst_id] = shr[0]; +} + +template +__device__ void +fast_min(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, T *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + shr[tid] = INFINITY; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + shr[tid] = ming(shr[tid], src[strided_i]); + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s) + shr[tid] = ming(shr[tid], shr[tid + s]); + } + + if (tid == 0) + dst[dst_id] = shr[0]; +} + +template +__device__ void +fast_argmin(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + __shared__ uint32_t shr_index[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + // Not sure how that works on uint32_t and uint8_t but it seems to do ok. + shr[tid] = INFINITY; + shr_index[tid] = 0xFFFFFFFF; + bool not_set = true; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (not_set || src[strided_i] < shr[tid]) { + shr[tid] = src[strided_i]; + // Assume that the reduction takes place over the last dimension which is contiguous. 
+ shr_index[tid] = idx % dims[num_dims - 1]; + not_set = false; + } + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s && shr[tid + s] < shr[tid]) { + shr[tid] = shr[tid + s]; + shr_index[tid] = shr_index[tid + s]; + } + } + + if (tid == 0) + dst[dst_id] = shr_index[0]; +} + +template +__device__ void +fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, + const size_t num_dims, const size_t *info, const T *src, uint32_t *dst) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + + __shared__ T shr[BLOCK_SIZE]; + __shared__ uint32_t shr_index[BLOCK_SIZE]; + size_t tid = threadIdx.x; + size_t dst_id = blockIdx.x; + + shr[tid] = -INFINITY; + shr_index[tid] = 0xFFFFFFFF; + bool not_set = true; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (not_set || src[strided_i] > shr[tid]) { + shr[tid] = src[strided_i]; + // Assume that the reduction takes place over the last dimension which is contiguous. + shr_index[tid] = idx % dims[num_dims - 1]; + not_set = false; + } + idx += blockDim.x; + } + + // Parallel reduction, see the slides: + // https://www.olcf.ornl.gov/wp-content/uploads/2019/12/05_Atomics_Reductions_Warp_Shuffle.pdf + // https://stackoverflow.com/questions/66078814/is-cuda-atomicadd-operation-faster-than-launch-another-kernel-when-we-do-reduce + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + __syncthreads(); + if (tid < s && shr[tid + s] > shr[tid]) { + shr[tid] = shr[tid + s]; + shr_index[tid] = shr_index[tid + s]; + } + } + + if (tid == 0) + dst[dst_id] = shr_index[0]; +} + +#define FAST_OP(TYPENAME, MIN_NAME, MAX_NAME, ARGMIN_NAME, ARGMAX_NAME, SUM_NAME) \ + extern "C" __global__ void ARGMIN_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + uint32_t *dst) { \ + fast_argmin(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void ARGMAX_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + uint32_t *dst) { \ + fast_argmax(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void MIN_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + TYPENAME *dst) { \ + fast_min(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void MAX_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, const size_t *info, const TYPENAME *src, \ + TYPENAME *dst) { \ + fast_max(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } \ + extern "C" __global__ void SUM_NAME( \ + const size_t src_numel, const size_t el_to_sum_per_block, \ + const size_t num_dims, 
const size_t *info, const TYPENAME *src, \ + TYPENAME *dst) { \ + fast_sum(src_numel, el_to_sum_per_block, num_dims, info, src, dst); \ + } + +#define SUM_OP(TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const size_t numel, const size_t num_dims, const size_t num_sum_dims, \ + const size_t *info, const TYPENAME *inp, TYPENAME *out) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + const size_t *sum_dims_l = info + 2 * num_dims; \ + const size_t *sum_dims_s = info + 2 * num_dims + num_sum_dims; \ + if (is_contiguous(num_dims, dims, strides)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \ + i += blockDim.x * gridDim.x) { \ + size_t dst_index = i; \ + for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ + size_t stride = sum_dims_s[nd]; \ + size_t pre = dst_index / stride; \ + size_t post = dst_index % stride; \ + dst_index = (pre / sum_dims_l[nd]) * stride + post; \ + } \ + atomicAdd(out + dst_index, inp[i]); \ + } \ + } else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; \ + i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + size_t dst_index = i; \ + for (unsigned int nd = 0; nd < num_sum_dims; ++nd) { \ + size_t stride = sum_dims_s[nd]; \ + size_t pre = dst_index / stride; \ + size_t post = dst_index % stride; \ + dst_index = (pre / sum_dims_l[nd]) * stride + post; \ + } \ + atomicAdd(out + dst_index, inp[strided_i]); \ + } \ + } \ + } + +#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, TYPENAME *dst, \ + const int n_cols) { \ + softmax(src, dst, n_cols); \ + } \ + +#define RMSNORM_OP(TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \ + const int n_cols, const int block_size, const float eps) { \ + rmsnorm(src, dst, alpha, n_cols, block_size, eps); \ + } \ + +#define LAYERNORM_OP(TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \ + const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \ + layernorm(src, dst, alpha, beta, n_cols, block_size, eps); \ + } \ + +#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \ + extern "C" __global__ void FN_NAME_I( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t bh, \ + const uint32_t td) { \ + ropei(src, cos, sin, dst, bh, td); \ + } \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t bh, \ + const uint32_t td, \ + const uint32_t d) { \ + rope(src, cos, sin, dst, bh, td, d); \ + } \ + extern "C" __global__ void FN_NAME_THD( \ + const TYPENAME *src, \ + const TYPENAME *cos, \ + const TYPENAME *sin, \ + TYPENAME *dst, \ + const uint32_t b, \ + const uint32_t t, \ + const uint32_t h, \ + const uint32_t d) { \ + rope_thd(src, cos, sin, dst, b, t, h, d); \ + } \ + +#if __CUDA_ARCH__ >= 800 +SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16) +RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16) +LAYERNORM_OP(__nv_bfloat16, layernorm_bf16) +ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) +SUM_OP(__nv_bfloat16, sum_bf16) +FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) + +// NOTE: No reduce ops for f8 +// SUM_OP(__nv_fp8_e4m3, 
sum_fp8_e4m3) +// SOFTMAX_OP(__nv_fp8_e4m3, float, softmax_fp8_e4m3) +// RMSNORM_OP(__nv_fp8_e4m3, rmsnorm_fp8_e4m3) +// LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3) +// ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3) +// FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3) +#endif + +#if __CUDA_ARCH__ >= 530 +SOFTMAX_OP(__half, float, softmax_f16) +RMSNORM_OP(__half, rmsnorm_f16) +LAYERNORM_OP(__half, layernorm_f16) +ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16) +SUM_OP(__half, sum_f16) +FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16) +#endif + +SUM_OP(float, sum_f32) +SUM_OP(double, sum_f64) +SUM_OP(uint32_t, sum_u32) +SOFTMAX_OP(float, float, softmax_f32) +SOFTMAX_OP(double, double, softmax_f64) +RMSNORM_OP(float, rmsnorm_f32) +RMSNORM_OP(double, rmsnorm_f64) +LAYERNORM_OP(float, layernorm_f32) +LAYERNORM_OP(double, layernorm_f64) +ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32) +ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) + +FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) +FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) +FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(int16_t, fast_min_i16, fast_max_i16, fast_argmin_i16, fast_argmax_i16, fast_sum_i16) +FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32) +FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64) +FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) diff --git a/mcandle/mcandle-kernels/src/sort.cu b/mcandle/mcandle-kernels/src/sort.cu new file mode 100644 index 000000000..7db1b20ec --- /dev/null +++ b/mcandle/mcandle-kernels/src/sort.cu @@ -0,0 +1,93 @@ +// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu +#define SORT_ORDER_ASC 1 +#define SORT_ORDER_DESC 0 +#include "cuda_utils.cuh" +#include + +template +static inline __device__ void ggml_cuda_swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template +static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols_pad) { + return; + } + + const T * x_row = x + row * ncols; + extern __shared__ int dst_row[]; + + // initialize indices + dst_row[col] = col; + + __syncthreads(); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + ggml_cuda_swap(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == SORT_ORDER_ASC ? 
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + ggml_cuda_swap(dst_row[col], dst_row[ixj]); + } + } + } + __syncthreads(); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +#define ASORT_OP(TYPENAME, RUST_NAME) \ +extern "C" __global__ void asort_asc_##RUST_NAME( \ + const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \ +) { \ + k_argsort(x, dst, ncols, ncols_pad); \ +} \ +extern "C" __global__ void asort_desc_##RUST_NAME( \ + const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \ +) { \ + k_argsort(x, dst, ncols, ncols_pad); \ +} \ + +#if __CUDA_ARCH__ >= 800 +ASORT_OP(__nv_bfloat16, bf16) + +// NOTE: No sort ops for f8 +// ASORT_OP(__nv_fp8_e4m3, fp8_e4m3) +#endif + +#if __CUDA_ARCH__ >= 530 +ASORT_OP(__half, f16) +#endif + +ASORT_OP(float, f32) +ASORT_OP(double, f64) +ASORT_OP(uint8_t, u8) +ASORT_OP(uint32_t, u32) +ASORT_OP(int16_t, i16) +ASORT_OP(int32_t, i32) +ASORT_OP(int64_t, i64) diff --git a/mcandle/mcandle-kernels/src/ternary.cu b/mcandle/mcandle-kernels/src/ternary.cu new file mode 100644 index 000000000..c426640b3 --- /dev/null +++ b/mcandle/mcandle-kernels/src/ternary.cu @@ -0,0 +1,91 @@ +#include "cuda_utils.cuh" +#include + +#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const ID_TYPENAME *ids, \ + const TYPENAME *t, \ + const TYPENAME *f, \ + TYPENAME *out \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + const size_t *strides_t = info + 2*num_dims; \ + const size_t *strides_f = info + 3*num_dims; \ + if (is_contiguous(num_dims, dims, strides) \ + && is_contiguous(num_dims, dims, strides_f) \ + && is_contiguous(num_dims, dims, strides_t)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + out[i] = ids[i] ? t[i] : f[i]; \ + } \ + } \ + else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + unsigned strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ + unsigned strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ + out[i] = ids[strided_i] ? 
t[strided_i_t] : f[strided_i_f]; \ + } \ + } \ +} \ + +#if __CUDA_ARCH__ >= 800 +WHERE_OP(__nv_bfloat16, int16_t, where_i16_bf16) +WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) +WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) +WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) +WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) + +WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3) +#endif + +#if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, int16_t, where_i16_f16) +WHERE_OP(__half, int32_t, where_i32_f16) +WHERE_OP(__half, int64_t, where_i64_f16) +WHERE_OP(__half, uint32_t, where_u32_f16) +WHERE_OP(__half, uint8_t, where_u8_f16) +#endif + +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(double, int16_t, where_i16_f64) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int64_t, int16_t, where_i16_i64) + +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(double, int32_t, where_i32_f64) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int32_t, int32_t, where_i32_i64) + +WHERE_OP(float, int64_t, where_i64_f32) +WHERE_OP(double, int64_t, where_i64_f64) +WHERE_OP(uint8_t, int64_t, where_i64_u8) +WHERE_OP(uint32_t, int64_t, where_i64_u32) +WHERE_OP(int64_t, int64_t, where_i64_i64) + +WHERE_OP(float, uint32_t, where_u32_f32) +WHERE_OP(double, uint32_t, where_u32_f64) +WHERE_OP(uint8_t, uint32_t, where_u32_u8) +WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(int16_t, uint32_t, where_u32_i16) +WHERE_OP(int32_t, uint32_t, where_u32_i32) +WHERE_OP(int64_t, uint32_t, where_u32_i64) + +WHERE_OP(float, uint8_t, where_u8_f32) +WHERE_OP(double, uint8_t, where_u8_f64) +WHERE_OP(uint8_t, uint8_t, where_u8_u8) +WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(int16_t, uint8_t, where_u8_i16) +WHERE_OP(int32_t, uint8_t, where_u8_i32) +WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/mcandle/mcandle-kernels/src/unary.cu b/mcandle/mcandle-kernels/src/unary.cu new file mode 100644 index 000000000..ca888ca3f --- /dev/null +++ b/mcandle/mcandle-kernels/src/unary.cu @@ -0,0 +1,233 @@ +#define _USE_MATH_DEFINES +#include +#include +#include "cuda_utils.cuh" + +#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const TYPENAME *inp, \ + TYPENAME *out \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + TYPENAME x = inp ? inp[i] : out[i]; \ + out[i] = FUNC; \ + } \ + } \ + else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + TYPENAME x = inp ? 
inp[strided_i] : out[i]; \ + out[i] = FUNC; \ + } \ + } \ +} \ + +template +__device__ __forceinline__ T gelu_erf_fwd(T x) { + return x * normcdfg(x); +} + +template +__device__ __forceinline__ T gelu_fwd(T x) { + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast(0.044715) * x_cube; + return static_cast(0.5) * x * (static_cast(1.0) + tanhg(static_cast(M_2_SQRTPI * M_SQRT1_2) * alpha)); +} + +template +__device__ __forceinline__ T elu_fwd(T x, T alpha) { + if (x > static_cast(0)) { + return x; + } + return alpha * (expg(x) - static_cast(1)); +} + +template +__device__ __forceinline__ T relu_fwd(T x) { + T zero = 0.; + return maxg(x, zero); +} + +template +__device__ __forceinline__ T silu_fwd(T x) { + return x / (static_cast(1) + expg(-x)); +} + +template +__device__ __forceinline__ T sigmoid_fwd(T x) { + return recipg(static_cast(1) + expg(-x)); +} + +#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const TYPENAME param, \ + const TYPENAME *inp, \ + TYPENAME *out \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + TYPENAME x = inp ? inp[i] : out[i]; \ + out[i] = FUNC; \ + } \ + } \ + else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + TYPENAME x = inp ? inp[strided_i] : out[i]; \ + out[i] = FUNC; \ + } \ + } \ +} \ + +template +__device__ T sign_(T t) { + return static_cast(t > static_cast(0)) - static_cast(t < static_cast(0)); +} + + +#if __CUDA_ARCH__ >= 800 +UNARY_OP(__nv_bfloat16, ucopy_bf16, x) +UNARY_OP(__nv_bfloat16, uneg_bf16, -x) +UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x)) +UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x)) +UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x)) +UNARY_OP(__nv_bfloat16, usin_bf16, sing(x)) +UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x)) +UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x)) +UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x)) +UNARY_OP(__nv_bfloat16, uceil_bf16, ceilg(x)) +UNARY_OP(__nv_bfloat16, ufloor_bf16, floorg(x)) +UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x)) +UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x)) +UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x)) +UNARY_OP(__nv_bfloat16, usqr_bf16, x*x) +UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x)) +UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x)) +UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x)) +UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x)) +UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param)) +UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) +UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) +UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) +UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x)) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +UNARY_OP(__nv_fp8_e4m3, ucopy_f8_e4m3, x) +UNARY_OP(__nv_fp8_e4m3, uneg_fp8_e4m3, __nv_fp8_e4m3(-F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, urecip_fp8_e4m3, recipg(x)) +UNARY_OP(__nv_fp8_e4m3, uexp_fp8_e4m3, expg(x)) +UNARY_OP(__nv_fp8_e4m3, ulog_fp8_e4m3, logg(x)) +UNARY_OP(__nv_fp8_e4m3, usin_fp8_e4m3, sing(x)) +UNARY_OP(__nv_fp8_e4m3, ucos_fp8_e4m3, cosg(x)) +UNARY_OP(__nv_fp8_e4m3, utanh_fp8_e4m3, tanhg(x)) +UNARY_OP(__nv_fp8_e4m3, 
uerf_fp8_e4m3, erfg(x)) +UNARY_OP(__nv_fp8_e4m3, uceil_fp8_e4m3, ceilg(x)) +UNARY_OP(__nv_fp8_e4m3, ufloor_fp8_e4m3, floorg(x)) +UNARY_OP(__nv_fp8_e4m3, uround_fp8_e4m3, roundg(x)) +UNARY_OP(__nv_fp8_e4m3, unormcdf_fp8_e4m3, normcdfg(x)) +UNARY_OP(__nv_fp8_e4m3, uabs_fp8_e4m3, absg(x)) +UNARY_OP(__nv_fp8_e4m3, usqr_fp8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x)*F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, usqrt_fp8_e4m3, sqrtg(x)) +UNARY_OP(__nv_fp8_e4m3, ugelu_fp8_e4m3, __nv_fp8_e4m3(gelu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, ugelu_erf_fp8_e4m3, __nv_fp8_e4m3(gelu_erf_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, urelu_fp8_e4m3, __nv_fp8_e4m3(relu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, uelu_fp8_e4m3, __nv_fp8_e4m3(elu_fwd(F8E4M3_TO_FLOAT(x), F8E4M3_TO_FLOAT(param)))) +UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param)) +UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x)))) +#endif + +#if __CUDA_ARCH__ >= 530 +UNARY_OP(__half, ucopy_f16, x) +UNARY_OP(__half, uneg_f16, -x) +UNARY_OP(__half, urecip_f16, recipg(x)) +UNARY_OP(__half, uexp_f16, expg(x)) +UNARY_OP(__half, ulog_f16, logg(x)) +UNARY_OP(__half, usin_f16, sing(x)) +UNARY_OP(__half, ucos_f16, cosg(x)) +UNARY_OP(__half, utanh_f16, tanhg(x)) +UNARY_OP(__half, uerf_f16, erfg(x)) +UNARY_OP(__half, uceil_f16, ceilg(x)) +UNARY_OP(__half, ufloor_f16, floorg(x)) +UNARY_OP(__half, uround_f16, roundg(x)) +UNARY_OP(__half, unormcdf_f16, normcdfg(x)) +UNARY_OP(__half, uabs_f16, absg(x)) +UNARY_OP(__half, usqr_f16, x*x) +UNARY_OP(__half, usqrt_f16, sqrtg(x)) +UNARY_OP(__half, ugelu_f16, gelu_fwd(x)) +UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x)) +UNARY_OP(__half, urelu_f16, relu_fwd(x)) +UNARY_OP1(__half, uelu_f16, elu_fwd(x, param)) +UNARY_OP(__half, usilu_f16, silu_fwd(x)) +UNARY_OP1(__half, upowf_f16, powg(x, param)) +UNARY_OP(__half, usign_f16, sign_(x)) +UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x)) +#endif + +UNARY_OP(uint8_t, ucopy_u8, x) +UNARY_OP(uint32_t, ucopy_u32, x) +UNARY_OP(int16_t, ucopy_i16, x) +UNARY_OP(int32_t, ucopy_i32, x) +UNARY_OP(int64_t, ucopy_i64, x) +UNARY_OP(float, ucopy_f32, x) +UNARY_OP(double, ucopy_f64, x) +UNARY_OP(float, uneg_f32, -x) +UNARY_OP(double, uneg_f64, -x) +UNARY_OP(float, urecip_f32, recipg(x)) +UNARY_OP(double, urecip_f64, recipg(x)) +UNARY_OP(float, uexp_f32, expg(x)) +UNARY_OP(double, uexp_f64, expg(x)) +UNARY_OP(float, ulog_f32, logg(x)) +UNARY_OP(double, ulog_f64, logg(x)) +UNARY_OP(float, usin_f32, sing(x)) +UNARY_OP(double, usin_f64, sing(x)) +UNARY_OP(float, ucos_f32, cosg(x)) +UNARY_OP(double, ucos_f64, cosg(x)) +UNARY_OP(float, utanh_f32, tanhg(x)) +UNARY_OP(double, utanh_f64, tanhg(x)) +UNARY_OP(float, uerf_f32, erfg(x)) +UNARY_OP(double, uerf_f64, erfg(x)) +UNARY_OP(float, uceil_f32, ceilg(x)) +UNARY_OP(double, uceil_f64, ceilg(x)) +UNARY_OP(float, ufloor_f32, floorg(x)) +UNARY_OP(double, ufloor_f64, floorg(x)) +UNARY_OP(float, uround_f32, roundg(x)) +UNARY_OP(double, uround_f64, roundg(x)) +UNARY_OP(float, unormcdf_f32, normcdfg(x)) +UNARY_OP(double, unormcdf_f64, normcdfg(x)) +UNARY_OP(float, uabs_f32, absg(x)) +UNARY_OP(double, uabs_f64, absg(x)) +UNARY_OP(float, usqr_f32, x*x) +UNARY_OP(double, usqr_f64, x*x) +UNARY_OP(float, usqrt_f32, sqrtg(x)) +UNARY_OP(double, usqrt_f64, sqrtg(x)) +UNARY_OP(float, ugelu_f32, gelu_fwd(x)) +UNARY_OP(double, 
ugelu_f64, gelu_fwd(x)) +UNARY_OP(float, ugelu_erf_f32, gelu_erf_fwd(x)) +UNARY_OP(double, ugelu_erf_f64, gelu_erf_fwd(x)) +UNARY_OP(float, urelu_f32, relu_fwd(x)) +UNARY_OP(double, urelu_f64, relu_fwd(x)) +UNARY_OP1(float, uelu_f32, elu_fwd(x, param)) +UNARY_OP1(double, uelu_f64, elu_fwd(x, param)) +UNARY_OP(float, usilu_f32, silu_fwd(x)) +UNARY_OP(double, usilu_f64, silu_fwd(x)) +UNARY_OP1(float, upowf_f32, powg(x, param)) +UNARY_OP1(double, upowf_f64, powg(x, param)) +UNARY_OP(float, usign_f32, sign_(x)) +UNARY_OP(double, usign_f64, sign_(x)) +UNARY_OP(float, usigmoid_f32, sigmoid_fwd(x)) +UNARY_OP(double, usigmoid_f64, sigmoid_fwd(x)) diff --git a/mcandle/mcandle-metal-kernels/Cargo.toml b/mcandle/mcandle-metal-kernels/Cargo.toml new file mode 100644 index 000000000..bb8e7ffc0 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "mcandle-metal-kernels" +version = "0.3.2" +edition = "2021" + +description = "Metal kernels for Candle" +repository = "https://github.com/huggingface/mcandle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" + + +[dependencies] +metal = { version = "0.27.0", features = ["mps"] } +once_cell = "1.18.0" +thiserror = "1" +tracing = "0.1.37" + +[dev-dependencies] +clap = { version = "4.2.4", features = ["derive"] } +half = { version = "2.3.1", features = [ + "num-traits", + "use-intrinsics", + "rand_distr", +] } +anyhow = "1" +rand = "0.8.5" +rand_distr = "0.4.3" diff --git a/mcandle/mcandle-metal-kernels/README.md b/mcandle/mcandle-metal-kernels/README.md new file mode 100644 index 000000000..b2c6836ac --- /dev/null +++ b/mcandle/mcandle-metal-kernels/README.md @@ -0,0 +1,3 @@ +# mcandle-metal-kernels + +This crate contains Metal kernels used from mcandle. \ No newline at end of file diff --git a/mcandle/mcandle-metal-kernels/examples/metal_benchmarks.rs b/mcandle/mcandle-metal-kernels/examples/metal_benchmarks.rs new file mode 100644 index 000000000..0c22bf612 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/examples/metal_benchmarks.rs @@ -0,0 +1,136 @@ +use anyhow::Result; +use mcandle_metal_kernels::GemmDType; +/// This example contains some simple benchmarks so that it's easy to run them in perf etc. 
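+// Hypothetical invocation (example name inferred from the file path):
+// cargo run --release -p mcandle-metal-kernels --example metal_benchmarks -- gemm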
+use clap::{Parser, Subcommand}; +use half::f16; + +fn run_gemm(f32: bool, n: usize) -> Result<()> { + const WARMUP_ITERS: usize = 2; + const MIN_DUR: f64 = 4.; + + let device = metal::Device::system_default().unwrap(); + + let (b, m, n, k) = (1, n, n, n); + let kernels = mcandle_metal_kernels::Kernels::new(); + let command_queue = device.new_command_queue(); + let options = metal::MTLResourceOptions::StorageModeManaged; + + let (lhs, rhs) = if f32 { + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&rhs) as u64, + options, + ); + (lhs, rhs) + } else { + let lhs: Vec = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&rhs) as u64, + options, + ); + (lhs, rhs) + }; + let (dtype, name, sizeof) = if f32 { + (GemmDType::F32, "sgemm", core::mem::size_of::()) + } else { + (GemmDType::F16, "hgemm", core::mem::size_of::()) + }; + let output = device.new_buffer((b * m * n * sizeof) as u64, options); + + for mlx in [false, true] { + let mut sum_dt = 0f64; + let mut iters = 0usize; + for idx in 0.. { + let command_buffer = command_queue.new_command_buffer(); + let start_time = std::time::Instant::now(); + if mlx { + mcandle_metal_kernels::call_mlx_gemm( + &device, + command_buffer, + &kernels, + dtype, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + } else { + mcandle_metal_kernels::call_gemm( + &device, + command_buffer, + &kernels, + name, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + let dt = start_time.elapsed().as_secs_f64(); + if idx < WARMUP_ITERS { + continue; + } + sum_dt += dt; + iters += 1; + if sum_dt > MIN_DUR { + break; + } + } + let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt); + let mlx = if mlx { "MLX" } else { "MFA" }; + println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}"); + } + + Ok(()) +} + +#[derive(Subcommand, Debug, Clone)] +enum Task { + Gemm, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// The benchmark to be run. 
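+    /// Currently only `gemm` is available; it times both the MFA (`call_gemm`) and
+    /// MLX (`call_mlx_gemm`) kernels in f16 and f32 at several square sizes.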
+ #[command(subcommand)] + task: Task, +} + +fn main() -> Result<()> { + let args = Args::parse(); + match args.task { + Task::Gemm => { + for f32 in [false, true] { + for n in [512, 1024, 2048, 4096] { + run_gemm(f32, n)?; + } + } + } + } + Ok(()) +} diff --git a/mcandle/mcandle-metal-kernels/src/affine.metal b/mcandle/mcandle-metal-kernels/src/affine.metal new file mode 100644 index 000000000..e5229f55e --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/affine.metal @@ -0,0 +1,126 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +using namespace metal; + +#define AFFINE(FN_NAME, T) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + constant float &add, \ + device const T *input, \ + device T *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = T(fma(float(input[id]), mul, add)); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + constant float &add, \ + device const T *input, \ + device T *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \ +} + +#define POWF(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \ +} + +#define ELU(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[id]; \ + output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \ +} \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \ + output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \ +} \ + + +AFFINE(affine_u8, uint8_t) +AFFINE(affine_u32, uint32_t) +AFFINE(affine_f32, float) +AFFINE(affine_f16, half) +POWF(powf_f32, float) +POWF(powf_f16, half) +ELU(elu_f32, float) +ELU(elu_f16, half) + + +#if defined(__HAVE_BFLOAT__) +AFFINE(affine_bf16, bfloat); +POWF(powf_bf16, bfloat); +ELU(elu_bf16, bfloat); +#endif diff --git 
a/mcandle/mcandle-metal-kernels/src/binary.metal b/mcandle/mcandle-metal-kernels/src/binary.metal new file mode 100644 index 000000000..4c558c2cd --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/binary.metal @@ -0,0 +1,129 @@ +#include + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +using namespace metal; + +#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *left, \ + device const TYPENAME *right, \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + TYPENAME x = left[tid]; \ + TYPENAME y = right[tid]; \ + output[tid] = OUT_TYPENAME(FN); \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *left_strides, \ + constant size_t *right_strides, \ + device const TYPENAME *left, \ + device const TYPENAME *right, \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \ + output[tid] = OUT_TYPENAME(FN); \ +} + +#define BINARY_OP(FN, NAME) \ +BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ +BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, int16_t, NAME##_i16, NAME##_i16_strided); \ +BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided); + +#define BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ +BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int16_t, uint8_t, NAME##_i16, NAME##_i16_strided); \ +BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided); + +#define INT64_BINARY_OP(NAME, FN) \ +BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); + +#define INT64_BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided); + +#define BFLOAT_BINARY_OP(FN, NAME) \ +BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); + +#define BFLOAT_BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided); + +BINARY_OP(x + y, add) +BINARY_OP(x - y, sub) +BINARY_OP(x * y, mul) +BINARY_OP(x / y, div) +BINARY_OP(MIN(x, y), min) +BINARY_OP(MAX(x, y), max) + +BINARY_OP_OUT(eq, x == y) +BINARY_OP_OUT(ne, x != y) +BINARY_OP_OUT(le, x <= y) +BINARY_OP_OUT(lt, x < y) +BINARY_OP_OUT(ge, x >= y) +BINARY_OP_OUT(gt, x > y) + +#if __METAL_VERSION__ >= 220 +INT64_BINARY_OP(add, x + y) +INT64_BINARY_OP(sub, x - y) +INT64_BINARY_OP(mul, x * y) +INT64_BINARY_OP(div, x / y) +INT64_BINARY_OP(min, MIN(x, y)) +INT64_BINARY_OP(max, MAX(x, y)) + +INT64_BINARY_OP_OUT(eq, x == y) +INT64_BINARY_OP_OUT(ne, x != y) +INT64_BINARY_OP_OUT(le, x <= y) 
+INT64_BINARY_OP_OUT(lt, x < y) +INT64_BINARY_OP_OUT(ge, x >= y) +INT64_BINARY_OP_OUT(gt, x > y) +#endif + +#if defined(__HAVE_BFLOAT__) +BFLOAT_BINARY_OP(x + y, add) +BFLOAT_BINARY_OP(x - y, sub) +BFLOAT_BINARY_OP(x * y, mul) +BFLOAT_BINARY_OP(x / y, div) +BFLOAT_BINARY_OP(MIN(x, y), min) +BFLOAT_BINARY_OP(MAX(x, y), max) + +BFLOAT_BINARY_OP_OUT(eq, x == y) +BFLOAT_BINARY_OP_OUT(ne, x != y) +BFLOAT_BINARY_OP_OUT(le, x <= y) +BFLOAT_BINARY_OP_OUT(lt, x < y) +BFLOAT_BINARY_OP_OUT(ge, x >= y) +BFLOAT_BINARY_OP_OUT(gt, x > y) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/cast.metal b/mcandle/mcandle-metal-kernels/src/cast.metal new file mode 100644 index 000000000..5a8324bf1 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/cast.metal @@ -0,0 +1,165 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + + +using namespace metal; + +#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(input[tid]); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(input[get_strided_index(tid, num_dims, dims, strides)]); \ +} \ + +#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(static_cast(input[tid])); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast(static_cast(input[get_strided_index(tid, num_dims, dims, strides)])); \ +} \ + +// u32 +CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) +CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) +CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) +CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t) +CAST(cast_u32_i16, cast_u32_i16_strided, uint32_t, int16_t) +#if __METAL_VERSION__ >= 220 +CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) +#endif + +// u8 +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) +CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) +CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int64_t) +CAST(cast_u8_i16, cast_u8_i16_strided, uint8_t, int16_t) +#if __METAL_VERSION__ >= 220 +CAST(cast_u8_i64, cast_u8_i64_strided, 
uint8_t, int64_t) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +#endif + +// f16 +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) +CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i16, cast_f16_i16_strided, half, int16_t) +CAST(cast_f16_i32, cast_f16_i32_strided, half, int64_t) +CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) +#endif + +// i64 +CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) +CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t) +CAST(cast_i64_i16, cast_i64_i16_strided, int64_t, int16_t) +CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) +#endif + +// i32 +CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float) +CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t) +CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t) +CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t) +CAST(cast_i32_i16, cast_i32_i16_strided, int32_t, int16_t) +CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int64_t, bfloat, float) +#endif + +// i16 +CAST(cast_i16_f32, cast_i16_f32_strided, int16_t, float) +CAST(cast_i16_u8, cast_i16_u8_strided, int16_t, uint8_t) +CAST(cast_i16_u32, cast_i16_u32_strided, int16_t, uint32_t) +CAST(cast_i16_i32, cast_i16_i32_strided, int16_t, int32_t) +CAST(cast_i16_i64, cast_i16_i64_strided, int16_t, int64_t) +CAST(cast_i16_f16, cast_i16_f16_strided, int16_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i16_bf16, cast_i16_bf16_strided, int16_t, bfloat, float) +#endif + +// f32 +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) +CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) +CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i16, cast_f32_i16_strided, float, int16_t) +CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t) +CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) +#endif + +// bf16 +#if defined(__HAVE_BFLOAT__) +CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i16, cast_bf16_i16_strided, bfloat, int16_t) +CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t) +CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) +CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) +CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) +CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) +#endif \ No newline at end of file diff --git a/mcandle/mcandle-metal-kernels/src/conv.metal b/mcandle/mcandle-metal-kernels/src/conv.metal new file mode 100644 index 000000000..5348a0f00 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/conv.metal @@ -0,0 +1,599 @@ +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? 
(x) : (y)) + +template +METAL_FUNC void im2col( + constant size_t &dst_numel, + constant size_t &h_out, + constant size_t &w_out, + constant size_t &h_k, + constant size_t &w_k, + constant size_t &stride, + constant size_t &padding, + constant size_t &dilation, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + // dst: (b_size, h_out, w_out, c_in, h_k, w_k) + // src: (b_size, c_in, h_in, w_in) + if (tid >= dst_numel) { + return; + } + const size_t b_in = src_dims[0]; + const size_t c_in = src_dims[1]; + const size_t h_in = src_dims[2]; + const size_t w_in = src_dims[3]; + + const size_t dst_s4 = w_k; + const size_t dst_s3 = h_k * dst_s4; + const size_t dst_s2 = c_in * dst_s3; + const size_t dst_s1 = w_out * dst_s2; + const size_t dst_s0 = h_out * dst_s1; + + size_t tmp_tid = tid; + const size_t b_idx = tmp_tid / dst_s0; + tmp_tid -= b_idx * dst_s0; + const size_t h_idx = tmp_tid / dst_s1; + tmp_tid -= h_idx * dst_s1; + const size_t w_idx = tmp_tid / dst_s2; + tmp_tid -= w_idx * dst_s2; + const size_t c_idx = tmp_tid / dst_s3; + tmp_tid -= c_idx * dst_s3; + const size_t h_k_idx = tmp_tid / dst_s4; + tmp_tid -= h_k_idx * dst_s4; + const size_t w_k_idx = tmp_tid; + size_t src_h_idx = h_idx * stride + h_k_idx * dilation; + size_t src_w_idx = w_idx * stride + w_k_idx * dilation; + if (src_h_idx < padding || src_h_idx >= h_in + padding) { + dst[tid] = static_cast(0); + } + else if (src_w_idx < padding || src_w_idx >= w_in + padding) { + dst[tid] = static_cast(0); + } + else { + src_h_idx -= padding; + src_w_idx -= padding; + const size_t src_i = + b_idx * src_strides[0] + + c_idx * src_strides[1] + + src_h_idx * src_strides[2] + + src_w_idx * src_strides[3]; + dst[tid] = src[src_i]; + } +} + +template +METAL_FUNC void col2im1d( + constant size_t &dst_el, + constant size_t &l_out, + constant size_t &l_in, + constant size_t &c_out, + constant size_t &k_size, + constant size_t &stride, + device const T *src, + device T *dst, + uint dst_i [[ thread_position_in_grid ]] +) { + // src: (b_size, l_in, c_out, l_k) + // dst: (b_size, c_out, l_out) + if (dst_i >= dst_el) { + return; + } + + const size_t dst_s0 = c_out * l_out; + const size_t dst_s1 = l_out; + const size_t src_s0 = c_out * k_size * l_in; + const size_t src_s1 = c_out * k_size; + const size_t src_s2 = k_size; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t c_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= c_idx * dst_s1; + const int l_out_idx = tmp_dst_i; + + dst[dst_i] = static_cast(0); + + int l_in_idx = l_out_idx / stride; + int k0 = l_out_idx - l_in_idx * stride; + // l_out_idx = l_in_idx * stride + k0 + for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) { + if (l_in_idx < l_in) { + const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0; + dst[dst_i] += src[src_i]; + } + } +} + +template +METAL_FUNC void im2col1d( + constant size_t &dst_numel, + constant size_t &l_out, + constant size_t &l_k, + constant size_t &stride, + constant size_t &padding, + constant size_t &dilation, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + // dst: (b_size, l_out, c_in, l_k) + // src: (b_size, c_in, l_in) + if (tid >= dst_numel) { + return; + } + const size_t b_in = src_dims[0]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + + const size_t 
dst_s2 = l_k; + const size_t dst_s1 = c_in * dst_s2; + const size_t dst_s0 = l_out * dst_s1; + + size_t tmp_dst_i = tid; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t l_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= l_idx * dst_s1; + const size_t c_idx = tmp_dst_i / dst_s2; + tmp_dst_i -= c_idx * dst_s2; + const size_t l_k_idx = tmp_dst_i; + size_t src_l_idx = l_idx * stride + l_k_idx * dilation; + if (src_l_idx < padding || src_l_idx >= l_in + padding) { + dst[tid] = static_cast(0); + } + else { + src_l_idx -= padding; + const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2]; + dst[tid] = src[src_i]; + } +} + +template +METAL_FUNC void upsample_nearest2d( + constant size_t &w_out, + constant size_t &h_out, + constant float &w_scale, + constant float &h_scale, + constant size_t *src_dims, + constant size_t *src_s, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + // src: (b_size, c_in, w_in, h_in) + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + size_t src_w = static_cast(dst_w * w_scale); + size_t src_h = static_cast(dst_h * h_scale); + if (src_w >= w_in) { + src_w = w_in - 1; + } + if (src_h >= h_in) { + src_h = h_in - 1; + } + + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + dst[tid] = src[src_i]; +} + +#define IM2COL_OP(T, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &dst_numel, \ + constant size_t &h_out, \ + constant size_t &w_out, \ + constant size_t &h_k, \ + constant size_t &w_k, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &dilation, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + im2col(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \ +} \ + +#define IM2COL1D_OP(T, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &dst_numel, \ + constant size_t &l_out, \ + constant size_t &l_k, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &dilation, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + im2col1d(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \ +} \ + +#define COL2IM1D_OP(T, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &dst_el, \ + constant size_t &l_out, \ + constant size_t &l_in, \ + constant size_t &c_out, \ + constant size_t &k_size, \ + constant size_t &stride, \ + device const T *src, \ + device T *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + col2im1d(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \ +} \ + +#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_out, \ + constant size_t &h_out, \ + constant float &w_scale, \ + constant float &h_scale, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ 
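The nearest-neighbour mapping in `upsample_nearest2d` boils down to "scale, truncate, clamp". A minimal Rust sketch (hypothetical names; it assumes the host passes `w_scale = w_in / w_out`, the usual convention, which is not shown in this part of the diff):

```rust
/// Map an output coordinate to its nearest-neighbour source coordinate, mirroring
/// `upsample_nearest2d`: scale, truncate (like static_cast<size_t>), then clamp.
fn nearest_src_index(dst_x: usize, scale: f32, in_size: usize) -> usize {
    let src = (dst_x as f32 * scale) as usize;
    src.min(in_size - 1)
}

fn main() {
    // Upsampling 4 -> 8 pixels; assumed convention: scale = in / out = 0.5.
    let (w_in, w_out) = (4usize, 8usize);
    let scale = w_in as f32 / w_out as f32;
    let mapped: Vec<usize> = (0..w_out).map(|x| nearest_src_index(x, scale, w_in)).collect();
    assert_eq!(mapped, vec![0, 0, 1, 1, 2, 2, 3, 3]);
}
```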
+ upsample_nearest2d(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \ +} \ + +template +METAL_FUNC void avg_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + d += static_cast(src[src_idx]); + } + } + dst[tid] = static_cast(d / (w_k * h_k)); +} + +#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + avg_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ +} \ + +template +METAL_FUNC void max_pool2d( + constant size_t &w_k, + constant size_t &h_k, + constant size_t &w_stride, + constant size_t &h_stride, + constant size_t *src_dims, + constant size_t *src_strides, + device const T *src, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + const size_t w_out = (w_in - w_k) / w_stride + 1; + const size_t h_out = (h_in - h_k) / h_stride + 1; + if (tid >= src_dims[0] * c * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c); + const size_t c_idx = (tid / (w_out * h_out)) % c; + const size_t dst_w = (tid / h_out) % w_out; + const size_t dst_h = tid % h_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + T d = 0; + bool set = false; + for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { + size_t src_w = w_stride * dst_w + w_offset; + if (src_w >= w_in){ + continue; + } + for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { + size_t src_h = h_stride * dst_h + h_offset; + if (src_h >= h_in) { + continue; + } + const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; + if (set) { + d = MAX(d, src[src_idx]); + } + else { + d = src[src_idx]; + set = true; + } + } + } + dst[tid] = d; +} + +#define MAXPOOL2D_OP(TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_k, \ + constant size_t &h_k, \ + constant size_t &w_s, \ + constant size_t &h_s, \ + constant size_t *src_dims, \ + constant size_t *src_s, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + max_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, 
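Both pooling kernels derive the output size as `(in - k) / stride + 1` (no padding), and the average variant divides by the full window size even when part of the window falls outside the input. A small Rust reference for those two pieces (invented names, contiguous `(w, h)` plane with `h` innermost, as in the kernels):

```rust
/// Pooling output length for one spatial axis, matching the kernels above.
fn pool_out_len(in_len: usize, k: usize, stride: usize) -> usize {
    (in_len - k) / stride + 1
}

/// Average one pooling window of a contiguous (w, h) plane; a CPU reference for
/// the accumulation loop in avg_pool2d (accumulate, then divide by k_w * k_h,
/// exactly as the kernel does even for partially out-of-bounds windows).
fn avg_window(plane: &[f32], w_in: usize, h_in: usize,
              x0: usize, y0: usize, k_w: usize, k_h: usize) -> f32 {
    let mut acc = 0f32;
    for dx in 0..k_w {
        for dy in 0..k_h {
            let (x, y) = (x0 + dx, y0 + dy);
            if x < w_in && y < h_in {
                acc += plane[x * h_in + y]; // kernel layout: h is the innermost axis
            }
        }
    }
    acc / (k_w * k_h) as f32
}

fn main() {
    assert_eq!(pool_out_len(5, 3, 1), 3);
    let plane = [1., 2., 3., 4.]; // 2x2 plane
    assert_eq!(avg_window(&plane, 2, 2, 0, 0, 2, 2), 2.5);
}
```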
dst, tid); \ +} \ + + +// Naive implementation of conv_transpose1d. +template +METAL_FUNC void conv_transpose1d( + constant size_t &l_out, + constant size_t &stride, + constant size_t &padding, + constant size_t &out_padding, + constant size_t &dilation, + constant size_t *src_dims, + constant size_t *src_strides, + constant size_t *k_dims, + constant size_t *k_strides, + device const T *src, + device const T *k, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + // src: (b_size, c_in, l_in) + // kernel: (c_in, c_out, l_k) + const size_t l_k = k_dims[2]; + const size_t c_out = k_dims[1]; + const size_t c_in = src_dims[1]; + const size_t l_in = src_dims[2]; + if (tid >= src_dims[0] * c_out * l_out) { + return; + } + + const size_t b_idx = tid / (l_out * c_out); + const size_t dst_c_idx = (tid / l_out) % c_out; + const size_t out_x = tid % l_out; + + const size_t src_idx0 = b_idx * src_strides[0]; + A d = 0; + for (int k_x = 0; k_x < (int)l_k; ++k_x) { + // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding; + int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + int inp_x = inp_x_stride / stride; + if (inp_x >= l_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2]; + const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2]; + d += static_cast(src[src_idx]) * static_cast(k[k_idx]); + } + } + dst[tid] = static_cast(d); +} + +#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &l_out, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &out_padding, \ + constant size_t &dilation, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ + constant size_t *k_dims, \ + constant size_t *k_strides, \ + device const TYPENAME *src, \ + device const TYPENAME *k, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + conv_transpose1d(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \ +} \ + +template +METAL_FUNC void conv_transpose2d( + constant size_t &w_out, + constant size_t &h_out, + constant size_t &stride, + constant size_t &padding, + constant size_t &out_padding, + constant size_t &dilation, + constant size_t *input_dims, + constant size_t *input_stride, + constant size_t *k_dims, + constant size_t *k_stride, + device const T *src, + device const T *k, + device T *dst, + uint tid [[ thread_position_in_grid ]] +) { + const size_t h_k = k_dims[2]; + const size_t w_k = k_dims[3]; + const size_t c_out = k_dims[1]; + const size_t c_in = input_dims[1]; + const size_t h_in = input_dims[2]; + const size_t w_in = input_dims[3]; + + if (tid >= input_dims[0] * c_out * w_out * h_out) { + return; + } + + const size_t b_idx = tid / (w_out * h_out * c_out); + const size_t dst_c_idx = (tid / (w_out * h_out)) % c_out; + const size_t out_y = (tid / w_out) % h_out; + const size_t out_x = tid % w_out; + + const size_t src_idx0 = b_idx * input_stride[0]; + + A d = 0; + for (int k_x = 0; k_x < (int)w_k; ++k_x) { + const int inp_x_stride = (int)(out_x + padding) - k_x * dilation; + if (inp_x_stride < 0 || inp_x_stride % stride) { + continue; + } + const int inp_x = inp_x_stride / stride; + if (inp_x >= w_in) continue; + for (int k_y = 0; k_y < (int)h_k; ++k_y) { + const int inp_y_stride = (int)(out_y + padding) - k_y 
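The transposed convolution above is written as a gather over outputs: each output position solves `out_x = inp_x * stride + k_x * dilation - padding` for `inp_x` and skips taps where the division is not exact. The following standalone Rust sketch (illustrative, not crate code) enumerates the contributing `(inp_x, k_x)` pairs the same way:

```rust
/// List the (inp_x, k_x) pairs that contribute to output position `out_x` in a
/// 1-D transposed convolution, using the same divisibility check as the kernel.
fn contributing_taps(out_x: usize, l_in: usize, l_k: usize,
                     stride: usize, padding: usize, dilation: usize) -> Vec<(usize, usize)> {
    let mut pairs = Vec::new();
    for k_x in 0..l_k {
        let shifted = out_x as i64 + padding as i64 - (k_x * dilation) as i64;
        if shifted < 0 || shifted % stride as i64 != 0 {
            continue; // no integer input position reaches out_x through this tap
        }
        let inp_x = (shifted / stride as i64) as usize;
        if inp_x < l_in {
            pairs.push((inp_x, k_x));
        }
    }
    pairs
}

fn main() {
    // l_in = 3, kernel size 3, stride 2, padding 0, dilation 1; with the usual
    // transposed-conv formula (out_padding = 0) l_out = (3 - 1) * 2 + (3 - 1) + 1 = 7.
    assert_eq!(contributing_taps(4, 3, 3, 2, 0, 1), vec![(2, 0), (1, 2)]);
}
```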
* dilation; + if (inp_y_stride < 0 || inp_y_stride % stride) { + continue; + } + const int inp_y = inp_y_stride / stride; + if (inp_y >= h_in) continue; + for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) { + const size_t src_idx = src_idx0 + src_c_idx * input_stride[1] + inp_y * input_stride[2] + inp_x * input_stride[3]; + const size_t k_idx = src_c_idx * k_stride[0] + dst_c_idx * k_stride[1] + k_y * k_stride[2] + k_x * k_stride[3]; + d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]); + } + } + dst[tid] = static_cast<T>(d); +} + +#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &w_out, \ + constant size_t &h_out, \ + constant size_t &stride, \ + constant size_t &padding, \ + constant size_t &out_padding, \ + constant size_t &dilation, \ + constant size_t *input_dims, \ + constant size_t *input_stride, \ + constant size_t *k_dims, \ + constant size_t *k_stride, \ + device const TYPENAME *src, \ + device const TYPENAME *k, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + conv_transpose2d<TYPENAME, TYPEACC>(w_out, h_out, stride, padding, out_padding, dilation, input_dims, input_stride, k_dims, k_stride, src, k, dst, tid); \ +} \ + +IM2COL_OP(float, im2col_f32) +IM2COL_OP(half, im2col_f16) +IM2COL_OP(uint8_t, im2col_u8) +IM2COL_OP(uint32_t, im2col_u32) +#if defined(__HAVE_BFLOAT__) +IM2COL_OP(bfloat, im2col_bf16) +#endif + +COL2IM1D_OP(float, col2im1d_f32) +COL2IM1D_OP(uint8_t, col2im1d_u8) +COL2IM1D_OP(uint32_t, col2im1d_u32) + +IM2COL1D_OP(float, im2col1d_f32) +IM2COL1D_OP(uint8_t, im2col1d_u8) +IM2COL1D_OP(uint32_t, im2col1d_u32) + +UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) +UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16) +UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) +UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) +#if defined(__HAVE_BFLOAT__) +UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16) +#endif + +MAXPOOL2D_OP(float, max_pool2d_f32) +MAXPOOL2D_OP(half, max_pool2d_f16) +MAXPOOL2D_OP(uint32_t, max_pool2d_u32) +MAXPOOL2D_OP(uint8_t, max_pool2d_u8) +#if defined(__HAVE_BFLOAT__) +MAXPOOL2D_OP(bfloat, max_pool2d_bf16) +#endif + +AVGPOOL2D_OP(float, float, avg_pool2d_f32) +AVGPOOL2D_OP(half, float, avg_pool2d_f16) +AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32) +AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8) +#if defined(__HAVE_BFLOAT__) +AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16) +#endif + +CONVT1D_OP(float, float, conv_transpose1d_f32) +CONVT1D_OP(half, float, conv_transpose1d_f16) +CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8) +CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32) +#if defined(__HAVE_BFLOAT__) +CONVT1D_OP(bfloat, float, conv_transpose1d_bf16) +#endif + +CONVT2D_OP(float, float, conv_transpose2d_f32) +CONVT2D_OP(half, float, conv_transpose2d_f16) +#if defined(__HAVE_BFLOAT__) +CONVT2D_OP(bfloat, float, conv_transpose2d_bf16) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/fill.metal b/mcandle/mcandle-metal-kernels/src/fill.metal new file mode 100644 index 000000000..7e99a8525 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/fill.metal @@ -0,0 +1,41 @@ +#include <metal_stdlib> + +using namespace metal; + +template <typename T> METAL_FUNC void fill_with( + device T *out, + constant float &value, + constant size_t &numel, + uint tid [[thread_position_in_grid]] +) { + if (tid >= numel) { + return; + } + out[tid] = static_cast<T>(value); +} + +#define FILL_OP(NAME, T) \ +kernel void fill_##NAME( \ + device T *out, \ + constant float &value, \ + constant size_t &numel, \ + uint tid
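The fill kernels being defined here are intentionally simple: the value always arrives as an `f32` constant and is cast to the destination element type, so integer buffers receive a truncated value. A tiny Rust analogue (hypothetical helper, not part of the crate):

```rust
/// CPU analogue of the fill kernels: every element is set to `value`, which is
/// always supplied as f32 and cast to the element type.
fn fill_ref<T: Copy>(out: &mut [T], value: f32, cast: impl Fn(f32) -> T) {
    for slot in out.iter_mut() {
        *slot = cast(value);
    }
}

fn main() {
    let mut buf = vec![0u8; 4];
    fill_ref(&mut buf, 3.7, |v| v as u8); // like fill_u8: the cast truncates 3.7 to 3
    assert_eq!(buf, vec![3, 3, 3, 3]);
}
```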
[[thread_position_in_grid]] \ +) { \ + fill_with(out, value, numel, tid); \ +} \ + + +#define FILL_OPS(NAME, T) \ +FILL_OP(NAME, T) \ + +FILL_OPS(u8, uchar) +FILL_OPS(u32, uint) +FILL_OPS(i64, long) +FILL_OPS(f16, half) +FILL_OPS(f32, float) +FILL_OPS(i32, int) +FILL_OPS(i16, short) + +#if __METAL_VERSION__ >= 310 +FILL_OPS(bf16, bfloat) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/indexing.metal b/mcandle/mcandle-metal-kernels/src/indexing.metal new file mode 100644 index 000000000..6de4a452f --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/indexing.metal @@ -0,0 +1,304 @@ +#include +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +template +METAL_FUNC void index( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + constant bool &contiguous, + constant size_t *src_dims, + constant size_t *src_strides, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t id_i = (tid / right_size) % ids_size; + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + */ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; + const size_t strided_src_i = contiguous ? 
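`get_strided_index` converts a flat row-major index into an offset through arbitrary strides by peeling dimensions from the innermost outwards. A direct Rust transcription (illustrative only) with a small transposed-layout example:

```rust
/// Rust mirror of get_strided_index: walk dimensions from the last (innermost)
/// to the first, taking the remainder against each dim and scaling by its stride.
fn strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for d in (0..dims.len()).rev() {
        offset += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    offset
}

fn main() {
    // A 2x3 view over a transposed (3x2 contiguous) buffer has strides (1, 2).
    // Row-major flat index 1 is element (0, 1), stored at offset 0*1 + 1*2 = 2.
    assert_eq!(strided_index(1, &[2, 3], &[1, 2]), 2);
}
```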
src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); + output[tid] = input[strided_src_i]; +} + +# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + constant bool &contiguous, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + index(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \ +} + + +template +METAL_FUNC void gather( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const INDEX_TYPENAME input_i = input_ids[tid]; + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; + output[tid] = input[src_i]; +} + +# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + gather(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ +} + +template +METAL_FUNC void scatter_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const INDEX_TYPENAME idx = input_ids[src_i]; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; + } +} + +# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + scatter_add(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \ +} + +template +METAL_FUNC void index_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + constant size_t &ids_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ 
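`gather` expects the index tensor to have the same shape as the output, `(left, ids_size, right)`, and each output element reads `src[(left_i * src_dim_size + ids[tid]) * right + right_i]`. A CPU reference in Rust (invented names, contiguous layout assumed):

```rust
/// CPU reference for the gather kernel; `ids` has the same element count as the output.
fn gather_ref(src: &[f32], src_dim_size: usize, ids: &[usize],
              left: usize, ids_size: usize, right: usize) -> Vec<f32> {
    assert_eq!(ids.len(), left * ids_size * right);
    let mut dst = vec![0f32; ids.len()];
    for tid in 0..dst.len() {
        let right_i = tid % right;
        let left_i = tid / (right * ids_size);
        let src_i = (left_i * src_dim_size + ids[tid]) * right + right_i;
        dst[tid] = src[src_i];
    }
    dst
}

fn main() {
    // src is a 1-D tensor of length 4 (left = right = 1); gather three indices.
    let src = [10., 20., 30., 40.];
    assert_eq!(gather_ref(&src, 4, &[2, 0, 3], 1, 3, 1), vec![30., 10., 40.]);
}
```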
thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const INDEX_TYPENAME idx = input_ids[j]; + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; + } +} + +# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &dst_dim_size, \ + constant size_t &ids_dim_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + index_add(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \ +} + + +INDEX_OP(is_i64_f32, int64_t, float) +INDEX_OP(is_i64_f16, int64_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i64_bf16, int64_t, bfloat) +#endif + +INDEX_OP(is_i32_f32, int32_t, float) +INDEX_OP(is_i32_f16, int32_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i32_bf16, int32_t, bfloat) +#endif + +INDEX_OP(is_i16_f32, int16_t, float) +INDEX_OP(is_i16_f16, int16_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i16_bf16, int16_t, bfloat) +#endif + +INDEX_OP(is_u32_u8, uint32_t, uint8_t) +INDEX_OP(is_u32_u32, uint32_t, uint32_t) +INDEX_OP(is_u32_f32, uint32_t, float) +INDEX_OP(is_u32_f16, uint32_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +#endif + +INDEX_OP(is_u8_u8, uint8_t, uint8_t) +INDEX_OP(is_u8_u32, uint8_t, uint32_t) +INDEX_OP(is_u8_f32, uint8_t, float) +INDEX_OP(is_u8_f16, uint8_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) +#endif + +GATHER_OP(gather_u32_f32, uint, float) +GATHER_OP(gather_u32_f16, uint, half) +#if defined(__HAVE_BFLOAT__) +GATHER_OP(gather_u32_bf16, uint, bfloat) +#endif + +SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) +SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i16_f32, int16_t, float) +SCATTER_ADD_OP(sa_i32_f32, int32_t, float) +SCATTER_ADD_OP(sa_i64_f32, int64_t, float) +SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) +SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i16_f16, int16_t, half) +SCATTER_ADD_OP(sa_i32_f16, int32_t, half) +SCATTER_ADD_OP(sa_i64_f16, int64_t, half) +#if defined(__HAVE_BFLOAT__) +SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) +SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat) +SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) +#endif + +// i64 +INDEX_ADD_OP(ia_i64_f16, int64_t, half) +INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_i16, int64_t, int16_t) +INDEX_ADD_OP(ia_i64_i32, int64_t, int32_t) +INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) +INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) +INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) +#endif + +// i32 +INDEX_ADD_OP(ia_i32_f16, int32_t, half) +INDEX_ADD_OP(ia_i32_f32, int32_t, float) +INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t) +INDEX_ADD_OP(ia_i32_i32, int32_t, int32_t) +INDEX_ADD_OP(ia_i32_u32, int32_t, uint32_t) +INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat) +#endif + +// i16 +INDEX_ADD_OP(ia_i16_f16, int16_t, half) 
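`index_add` assigns one thread per `(left, right)` pair and walks the whole id list, so duplicate indices accumulate without data races. A Rust sketch of the same semantics (illustrative; it assumes the source dimension equals the id count, which is how the host-side caller later in this patch uses it):

```rust
/// CPU reference for index_add: src has shape (left, ids.len(), right),
/// dst has shape (left, dst_dim, right); duplicate ids simply add up.
fn index_add_ref(dst: &mut [f32], dst_dim: usize, src: &[f32],
                 ids: &[usize], left: usize, right: usize) {
    let src_dim = ids.len();
    for left_i in 0..left {
        for right_i in 0..right {
            for (j, &idx) in ids.iter().enumerate() {
                let src_i = (left_i * src_dim + j) * right + right_i;
                let dst_i = (left_i * dst_dim + idx) * right + right_i;
                dst[dst_i] += src[src_i];
            }
        }
    }
}

fn main() {
    // Scatter source values [1, 2, 3] into slots 1, 0, 1 of a length-4 destination.
    let mut dst = vec![0f32; 4];
    index_add_ref(&mut dst, 4, &[1., 2., 3.], &[1, 0, 1], 1, 1);
    assert_eq!(dst, vec![2., 4., 0., 0.]);
}
```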
+INDEX_ADD_OP(ia_i16_f32, int16_t, float) +INDEX_ADD_OP(ia_i16_i16, int16_t, int16_t) +INDEX_ADD_OP(ia_i16_i32, int16_t, int32_t) +INDEX_ADD_OP(ia_i16_i64, int16_t, int64_t) +INDEX_ADD_OP(ia_i16_u32, int16_t, uint32_t) +INDEX_ADD_OP(ia_i16_u8, int16_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i16_bf16, int16_t, bfloat) +#endif + + +// u32 +INDEX_ADD_OP(ia_u32_f16, uint32_t, half) +INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_i16, uint32_t, int16_t) +INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t) +INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) +INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) +INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) +#endif + +// u8 +INDEX_ADD_OP(ia_u8_f16, uint8_t, half) +INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_i16, uint8_t, int16_t) +INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t) +INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) +INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) +INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/lib.rs b/mcandle/mcandle-metal-kernels/src/lib.rs new file mode 100644 index 000000000..93362299b --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/lib.rs @@ -0,0 +1,2905 @@ +use metal::{ + Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, + FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, +}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::sync::RwLock; + +pub mod utils; +pub use utils::BufferOffset; +use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; + +const AFFINE: &str = include_str!("affine.metal"); +const BINARY: &str = include_str!("binary.metal"); +const CAST: &str = include_str!("cast.metal"); +const CONV: &str = include_str!("conv.metal"); +const FILL: &str = include_str!("fill.metal"); +const INDEXING: &str = include_str!("indexing.metal"); +// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/mcandle +const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); +const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); +const QUANTIZED: &str = include_str!("quantized.metal"); +const RANDOM: &str = include_str!("random.metal"); +const REDUCE: &str = include_str!("reduce.metal"); +const SORT: &str = include_str!("sort.metal"); +const TERNARY: &str = include_str!("ternary.metal"); +const UNARY: &str = include_str!("unary.metal"); +const SDPA: &str = include_str!("scaled_dot_product_attention.metal"); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Source { + Affine, + Binary, + Cast, + Conv, + Fill, + Gemm, + Indexing, + Mfa, + Quantized, + Random, + Reduce, + Sort, + Ternary, + Unary, + Sdpa, +} + +pub mod copy2d { + pub struct Kernel(pub &'static str); + pub const FLOAT: Kernel = Kernel("copy2d_f32"); + pub const HALF: Kernel = Kernel("copy2d_f16"); + pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); + pub const I64: Kernel = Kernel("copy2d_i64"); + pub const I32: Kernel = Kernel("copy2d_i32"); + pub const I16: Kernel = Kernel("copy2d_i16"); + pub const U32: Kernel = Kernel("copy2d_u32"); + pub const U8: Kernel = Kernel("copy2d_u8"); +} + +macro_rules! 
ops{ + ($($name:ident),+) => { + + pub mod contiguous { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); + } + )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32"); + pub const HALF: Kernel = Kernel("copy_f16"); + pub const BFLOAT: Kernel = Kernel("copy_bf16"); + pub const I64: Kernel = Kernel("copy_i64"); + pub const I32: Kernel = Kernel("copy_i32"); + pub const I16: Kernel = Kernel("copy_i16"); + pub const U32: Kernel = Kernel("copy_u32"); + pub const U8: Kernel = Kernel("copy_u8"); + } + } + + pub mod contiguous_tiled { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_tiled")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); + } + )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32_tiled"); + pub const HALF: Kernel = Kernel("copy_f16_tiled"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); + pub const I64: Kernel = Kernel("copy_i64_tiled"); + pub const I32: Kernel = Kernel("copy_i32_tiled"); + pub const I16: Kernel = Kernel("copy_i16_tiled"); + pub const U32: Kernel = Kernel("copy_u32_tiled"); + pub const U8: Kernel = Kernel("copy_u8_tiled"); + } + } + + pub mod strided { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided")); + pub const I16: Kernel = Kernel(concat!(stringify!($name), "_i16_strided")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); + } + )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32_strided"); + pub const HALF: Kernel = Kernel("copy_f16_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); + pub const I64: Kernel = Kernel("copy_i64_strided"); + pub const I32: Kernel = Kernel("copy_i32_strided"); + pub const I16: Kernel = Kernel("copy_i16_strided"); + pub const U32: Kernel = Kernel("copy_u32_strided"); + pub const U8: Kernel = Kernel("copy_u8_strided"); + 
} + } + }; +} + +pub mod unary { + ops!( + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, + tanh, recip, silu, sign, sigmoid + ); +} +pub mod binary { + ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); +} + +#[derive(thiserror::Error, Debug)] +pub enum MetalKernelError { + #[error("Could not lock kernel map: {0}")] + LockError(String), + #[error("Error while loading library: {0}")] + LoadLibraryError(String), + #[error("Error while loading function: {0:?}")] + LoadFunctionError(String), + #[error("Failed to create compute function")] + FailedToCreateComputeFunction, + #[error("Failed to create pipeline")] + FailedToCreatePipeline(String), + #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec, + rhs_stride: Vec, + mnk: (usize, usize, usize), + }, + #[error("Sdpa {variation} head size was {got}, expectd {expected:?}")] + SdpaHeadSizeMismatch { + variation: &'static str, + got: usize, + expected: Vec, + }, + #[error("Sdpa {variation} got dtype {got:?}")] + SdpaHeadDTypeMismatch { + variation: &'static str, + got: SdpaDType, + }, +} + +impl From> for MetalKernelError { + fn from(e: std::sync::PoisonError) -> Self { + Self::LockError(e.to_string()) + } +} + +type Libraries = HashMap; +type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; + +#[derive(Debug)] +pub struct Kernels { + libraries: RwLock, + pipelines: RwLock, +} + +impl Default for Kernels { + fn default() -> Self { + Self::new() + } +} + +impl Kernels { + pub fn new() -> Self { + let libraries = RwLock::new(Libraries::new()); + let pipelines = RwLock::new(Pipelines::new()); + Self { + libraries, + pipelines, + } + } + + fn get_library_source(&self, source: Source) -> &'static str { + match source { + Source::Affine => AFFINE, + Source::Binary => BINARY, + Source::Cast => CAST, + Source::Conv => CONV, + Source::Fill => FILL, + Source::Gemm => MLX_GEMM, + Source::Indexing => INDEXING, + Source::Quantized => QUANTIZED, + Source::Random => RANDOM, + Source::Reduce => REDUCE, + Source::Sort => SORT, + Source::Ternary => TERNARY, + Source::Unary => UNARY, + Source::Sdpa => SDPA, + Source::Mfa => panic!("Invalid lib"), + } + } + + /// Load the give library from its [`source`]. + /// If this has been previously loaded it will just fetch it from cache. + pub fn load_library( + &self, + device: &Device, + source: Source, + ) -> Result { + let mut libraries = self.libraries.write()?; + if let Some(lib) = libraries.get(&source) { + Ok(lib.clone()) + } else { + let lib = match source { + Source::Mfa => { + let source_data = MFA; + device.new_library_with_data(source_data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" + )) + })? + } + source => { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } + }; + libraries.insert(source, lib.clone()); + Ok(lib) + } + } + + fn load_function( + &self, + device: &Device, + source: Source, + name: &'static str, + constants: Option, + ) -> Result { + let func = self + .load_library(device, source)? 
+ .get_function(name, constants) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + Ok(func) + } + + /// Load the give pipeline + /// loads the library from source, then gets the function [`name`] from + /// that source + fn load_pipeline_with_constants( + &self, + device: &Device, + source: Source, + name: &'static str, + constants: Option, + ) -> Result { + let mut pipelines = self.pipelines.write()?; + let key = (name, constants); + if let Some(pipeline) = pipelines.get(&key) { + Ok(pipeline.clone()) + } else { + let (name, constants) = key; + let func = self.load_function( + device, + source, + name, + constants.as_ref().map(|c| c.function_constant_values()), + )?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert((name, constants), pipeline.clone()); + + Ok(pipeline) + } + } + + /// Load the give pipeline + /// loads the library from source, then gets the function [`name`] from + /// that source (without constants) + pub fn load_pipeline( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result { + self.load_pipeline_with_constants(device, source, name, None) + } +} + +#[allow(clippy::too_many_arguments)] +pub fn call_copy2d( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: copy2d::Kernel, + input: &Buffer, + output: &Buffer, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o_in_bytes: usize, + dst_o_in_bytes: usize, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + d1 as i64, + d2 as i64, + src_s as i64, + dst_s as i64, + (input, src_o_in_bytes), + (output, dst_o_in_bytes) + ) + ); + + let grid_dims = MTLSize { + width: d1 as u64, + height: d2 as u64, + depth: 1, + }; + let group_dims = get_block_dims(d1 as u64, d2 as u64, 1); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous_tiled( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: unary::contiguous_tiled::Kernel, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let tile_size = 2; + let tiles = (length + tile_size - 1) / tile_size; + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, 
Source::Unary, kernel_name.0)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: unary::strided::Kernel, + shape: &[usize], + input: BufferOffset, + strides: &[usize], + output: BufferOffset, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; + + let length: usize = shape.iter().product(); + let num_dims: usize = shape.len(); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.set_compute_pipeline_state(&pipeline); + set_params!(encoder, (length, num_dims, shape, strides, &input, &output)); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_binary_contiguous( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: binary::contiguous::Kernel, + length: usize, + left: BufferOffset, + right: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &left, &right, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_binary_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: binary::strided::Kernel, + shape: &[usize], + left_input: BufferOffset, + left_strides: &[usize], + right_input: BufferOffset, + right_strides: &[usize], + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; + + let num_dims: usize = shape.len(); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let width: usize = shape.iter().product(); + let length: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + length, + num_dims, + shape, + left_strides, + right_strides, + &left_input, + &right_input, + output + ) + ); + encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); + 
encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_contiguous( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + input: BufferOffset, + input_strides: &[usize], + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + + set_params!( + encoder, + (length, shape.len(), shape, input_strides, &input, output) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_reduce_contiguous( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + out_length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let elements_to_sum = length / out_length; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, elements_to_sum, &input, output)); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + (elements_to_sum as u64 + 2 - 1) / 2, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_reduce_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + strides: &[usize], + out_length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let length: usize = shape.iter().product(); + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let elements_to_sum = 
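`call_reduce_contiguous` launches one threadgroup per output element and sizes it to the next power of two of roughly half the elements each output folds, capped at the pipeline's maximum threads per threadgroup. The arithmetic, isolated into a small Rust sketch (hypothetical function name):

```rust
/// Threadgroup width used by the contiguous reduction dispatch: cap at the
/// pipeline limit, then round up to a power of two. The halving of the element
/// count ((n + 2 - 1) / 2) mirrors the host code; presumably each thread folds
/// two inputs before the in-threadgroup tree reduction.
fn reduce_threadgroup_width(max_total_threads: u64, elements_to_sum: u64) -> u64 {
    let per_thread = (elements_to_sum + 2 - 1) / 2;
    std::cmp::min(max_total_threads, per_thread).next_power_of_two()
}

fn main() {
    assert_eq!(reduce_threadgroup_width(1024, 300), 256);   // ceil(300 / 2) = 150 -> 256
    assert_eq!(reduce_threadgroup_width(1024, 4096), 1024); // capped at the pipeline limit
}
```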
length / out_length; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + (shape.len(), shape, strides, elements_to_sum, &input, output) + ); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_last_softmax( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + input: &Buffer, + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); + + let out_length = length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +// Requires continuous input and mask +#[allow(clippy::too_many_arguments)] +pub fn call_last_attn_softmax( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + input: &Buffer, + input_offset: usize, + mask: &Buffer, + mask_offset: usize, + input_shape: &[usize], + scale: f32, + ty: SdpaDType, + output: &Buffer, +) -> Result<(), MetalKernelError> { + // Everything is in reverse + let ne00 = input_shape[input_shape.len() - 1] as i64; + let ne01 = input_shape[input_shape.len() - 2] as i64; + let ne02 = input_shape[input_shape.len() - 3] as i64; + let ne03 = input_shape[input_shape.len() - 4] as i64; + + let mut nth = 32; // SIMD width + let name = if ne00 % 4 == 0 { + while nth < ne00 / 4 && nth * ne01 * ne02 * ne03 < 256 { + nth *= 2; + } + match ty { + SdpaDType::F32 => "attn_soft_max_f32_4", + SdpaDType::F16 => "attn_soft_max_f16_4", + SdpaDType::BF16 => "attn_soft_max_bf16_4", + } + } else { + while nth < ne00 && nth * ne01 * ne02 * ne03 < 256 { + nth *= 2; + } + match ty { + SdpaDType::F32 => "attn_soft_max_f32", + SdpaDType::F16 => "attn_soft_max_f16", + SdpaDType::BF16 => "attn_soft_max_bf16", + } + }; + + let pipeline = kernels.load_pipeline(device, Source::Reduce, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + (input, input_offset), + (mask, mask_offset), + output, + ne00, + ne01, + ne02, + scale + ) + ); + + let thread_group_count = MTLSize { + width: (ne01 * ne02 * ne03) as u64, + height: 1, + 
depth: 1, + }; + let thread_group_size = MTLSize { + width: nth as u64, + height: 1, + depth: 1, + }; + + encoder.set_threadgroup_memory_length(0, 32 * std::mem::size_of::() as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rms_norm( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + eps: f32, + input: &Buffer, + input_offset: usize, + alpha: &Buffer, + alpha_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + length, + elements_to_sum, + (input, input_offset), + output, + (alpha, alpha_offset), + eps + ) + ); + + let out_length = length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_layer_norm( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + eps: f32, + input: &Buffer, + input_offset: usize, + alpha: &Buffer, + alpha_offset: usize, + beta: &Buffer, + beta_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + length, + elements_to_sum, + (input, input_offset), + output, + (alpha, alpha_offset), + (beta, beta_offset), + eps + ) + ); + + let out_length = length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rope_i( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + bh: usize, + td: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + 
encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + bh, + td, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rope_thd( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + b: usize, + t: usize, + h: usize, + d: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + b, + t, + h, + d, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_rope( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + kernel_name: &'static str, + bh: usize, + td: usize, + d: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + bh, + td, + d, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_affine( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + size: usize, + input: BufferOffset, + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, add, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input.buffer, 
metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_affine_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: BufferOffset, + input_stride: &[usize], + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + add, + &input, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + size: usize, + input: BufferOffset, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: BufferOffset, + input_stride: &[usize], + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + (size, shape.len(), shape, input_stride, mul, &input, output) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + size: usize, + input: BufferOffset, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + 
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: BufferOffset, + input_stride: &[usize], + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + (size, shape.len(), shape, input_stride, mul, &input, output) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_where_cond_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + cond: BufferOffset, + cond_stride: &[usize], + left: BufferOffset, + left_stride: &[usize], + right: BufferOffset, + right_stride: &[usize], + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + let size: usize = shape.iter().product(); + let rank = shape.len(); + + set_params!( + encoder, + ( + size, + rank, + shape, + cond_stride, + left_stride, + right_stride, + &cond, + &left, + &right, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + + encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_index_select( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + contiguous: bool, + src_dims: &[usize], + src_strides: &[usize], + input: BufferOffset, + ids: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + contiguous, + src_dims, + src_strides, + &input, + &ids, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + 
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_gather( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + input: BufferOffset, + ids: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + &input, + &ids, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_scatter_add( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + dim: usize, + input: BufferOffset, + ids: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + &input, + &ids, + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_index_add( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + ids_shape: &[usize], + dim: usize, + input: BufferOffset, + ids: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + let ids_dim_size = ids_shape[0]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + 
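All of the indexing kernels (`call_index_select`, `call_gather`, `call_scatter_add`, `call_index_add`) view the tensor as `(left_size, shape[dim], right_size)` blocks around the indexed dimension. The gather-style calls launch `ids_size * left_size * right_size` threads, while the scatter-style ones launch over `left_size * right_size`. A small standalone check of the gather-style arithmetic (hypothetical helper, shown only for clarity):

```rust
// Gather-style launch size: one thread per (left, id, right) triple.
fn gather_launch_size(shape: &[usize], dim: usize, ids_size: usize) -> usize {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    ids_size * left_size * right_size
}

fn main() {
    // index_select of 5 ids along dim 1 of a (2, 7, 3) tensor yields a (2, 5, 3) output.
    assert_eq!(gather_launch_size(&[2, 7, 3], 1, 5), 2 * 5 * 3);
}
```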
set_params!(
+        encoder,
+        (
+            dst_el,
+            left_size,
+            src_dim_size,
+            right_size,
+            dst_dim_size,
+            ids_dim_size,
+            &input,
+            &ids,
+            output
+        )
+    );
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    Ok(())
+}
+
+#[derive(Debug, PartialEq)]
+pub enum Value {
+    USize(usize),
+    Bool(bool),
+    F32(f32),
+    U16(u16),
+}
+
+impl std::hash::Hash for Value {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        match self {
+            Value::F32(v) => v.to_bits().hash(state),
+            Value::USize(v) => v.hash(state),
+            Value::U16(v) => v.hash(state),
+            Value::Bool(v) => v.hash(state),
+        }
+    }
+}
+
+impl Value {
+    fn data_type(&self) -> MTLDataType {
+        match self {
+            Value::USize(_) => MTLDataType::UInt,
+            Value::F32(_) => MTLDataType::Float,
+            Value::U16(_) => MTLDataType::UShort,
+            Value::Bool(_) => MTLDataType::Bool,
+        }
+    }
+}
+
+/// Not strictly true (f32 is not `Eq`), but good enough for our purposes here.
+impl Eq for Value {}
+
+#[derive(Debug, Eq, PartialEq, Hash)]
+struct ConstantValues(Vec<(usize, Value)>);
+
+impl ConstantValues {
+    pub fn new(values: Vec<(usize, Value)>) -> Self {
+        Self(values)
+    }
+
+    fn function_constant_values(&self) -> FunctionConstantValues {
+        let f = FunctionConstantValues::new();
+        for (index, value) in &self.0 {
+            let ty = value.data_type();
+            match value {
+                Value::USize(v) => {
+                    f.set_constant_value_at_index(
+                        v as *const usize as *const c_void,
+                        ty,
+                        *index as u64,
+                    );
+                }
+                Value::F32(v) => {
+                    f.set_constant_value_at_index(
+                        v as *const f32 as *const c_void,
+                        ty,
+                        *index as u64,
+                    );
+                }
+                Value::U16(v) => {
+                    f.set_constant_value_at_index(
+                        v as *const u16 as *const c_void,
+                        ty,
+                        *index as u64,
+                    );
+                }
+                Value::Bool(v) => {
+                    f.set_constant_value_at_index(
+                        v as *const bool as *const c_void,
+                        ty,
+                        *index as u64,
+                    );
+                }
+            }
+        }
+        f
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_gemm(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    name: &'static str,
+    (b, m, n, k): (usize, usize, usize, usize),
+    lhs_stride: &[usize],
+    lhs_offset: usize,
+    lhs_buffer: &Buffer,
+    rhs_stride: &[usize],
+    rhs_offset: usize,
+    rhs_buffer: &Buffer,
+    output: &Buffer,
+    alpha: f32,
+    beta: f32,
+) -> Result<(), MetalKernelError> {
+    assert!(rhs_stride.len() >= 2);
+    assert!(lhs_stride.len() >= 2);
+    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+    // lhs has shape b, m, k
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
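The `a_trans`/`b_trans` check that follows (and the identical one in `call_mlx_gemm` further down) classifies each operand by its last two strides: minor stride 1 means a contiguous row-major layout, minor stride equal to the number of rows means a transposed view, and anything else is rejected as non-contiguous. A simplified sketch of that rule, leaving out the size-1 special cases the real code also accepts:

```rust
// Simplified stride classification used before dispatching the GEMM kernels.
// `rows`/`cols` are the logical dimensions of the operand (m x k for lhs, k x n for rhs).
fn is_transposed(stride: &[usize], rows: usize, cols: usize) -> Option<bool> {
    let minor = stride[stride.len() - 1];
    let major = stride[stride.len() - 2];
    if minor == 1 && major == cols {
        Some(false) // contiguous row-major layout
    } else if minor == rows && major == 1 {
        Some(true) // column-major, i.e. a transposed view
    } else {
        None // neither: the caller returns MatMulNonContiguous
    }
}

fn main() {
    // lhs of shape (m, k) = (8, 16): contiguous strides vs. a transposed view.
    assert_eq!(is_transposed(&[16, 1], 8, 16), Some(false));
    assert_eq!(is_transposed(&[1, 8], 8, 16), Some(true));
}
```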
+ let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + false + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { + true + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + // rhs has shape b, k, n + let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + false + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { + true + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let d_trans = false; + let batched = b > 1; + let fused_activation = false; + let fused_bias = false; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 8; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 32; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; + let constants = Some(ConstantValues::new(vec![ + (0, Value::USize(m)), + (1, Value::USize(n)), + (2, Value::USize(k)), + (10, Value::Bool(a_trans)), + (11, Value::Bool(b_trans)), + (13, Value::Bool(d_trans)), + (20, Value::F32(alpha)), + (21, Value::F32(beta)), + (100, Value::Bool(batched)), + (101, Value::Bool(fused_activation)), + // Garbage + (102, Value::Bool(false)), + (103, Value::Bool(false)), + (113, Value::Bool(false)), + (50_000, Value::Bool(false)), + // End garbage + (200, Value::U16(m_simd)), + (201, Value::U16(n_simd)), + (202, Value::U16(k_simd)), + (210, Value::U16(m_splits)), + (211, Value::U16(n_splits)), + (50_001, Value::Bool(fused_bias)), + ])); + let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; + let m_group = m_simd * m_splits; + let n_group = n_simd * n_splits; + + let a_block_length = m_group * k_simd; + let b_block_length = k_simd * n_group; + + let mut block_elements = a_block_length + b_block_length; + if (m % 8 != 0) && (n % 8 != 0) { + let c_block_length = m_group * n_group; + block_elements = std::cmp::max(c_block_length, block_elements) + } + if fused_bias { + if d_trans { + block_elements = std::cmp::max(block_elements, m_group); + } else { + block_elements = std::cmp::max(block_elements, n_group); + } + } + let bytes = match name { + "sgemm" => 4, + "hgemm" => 2, + "bgemm" => 2, + other => { + return Err(MetalKernelError::LoadLibraryError(format!( + "{other} is not a valid kernel for gemm" + ))); + } + }; + let block_bytes = block_elements * bytes; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, block_bytes.into()); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(2, Some(output), 0); + // TODO Tensor D + + let grid_z = b; + if batched { + let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; + let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; + let byte_stride_c = m * n * bytes as usize; + // TODO byte_stride_d + let byte_stride_d = 0; + + let buffer: Vec = vec![ + byte_stride_a as _, + byte_stride_b as _, + byte_stride_c as _, + byte_stride_d as _, + ]; + encoder.set_bytes( + 10, + (buffer.len() * 
core::mem::size_of::()) as NSUInteger, + buffer.as_ptr() as *const NSUInteger as *const c_void, + ); + } + + let grid_size = MTLSize { + width: divide(n, n_group.into()), + height: divide(m, m_group.into()), + depth: grid_z as NSUInteger, + }; + let group_size = MTLSize { + width: 32 * (m_splits as u64) * (n_splits as u64), + height: 1, + depth: 1, + }; + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + Ok(()) +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub enum SdpaDType { + BF16, + F16, + F32, +} + +/// SDPA full is supported when: +/// - q head dim == 64, 128 +/// - no mask +/// - q heads == kv heads +/// - final type != bf16 (TODO maybe just template this kernel too?) +/// - q,k,v are contiguous +#[allow(clippy::too_many_arguments)] +pub fn call_sdpa_full( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + q_offset: usize, + q_shape: &[usize], + q_buffer: &Buffer, + k_offset: usize, + k_buffer: &Buffer, + v_offset: usize, + v_buffer: &Buffer, + output: &Buffer, + alpha: f32, + softcapping: f32, + itype: SdpaDType, +) -> Result<(), MetalKernelError> { + #[derive(Debug)] + #[repr(C)] + struct MLXFastAttentionParams { + m: i32, + n: i32, + k: i32, + + ldq: i32, // ldq == ldo + ldk: i32, + ldv: i32, + lds: i32, + ldo: i32, + + tiles_n: i32, + tiles_m: i32, + + batch_stride_q: i32, + batch_stride_k: i32, + batch_stride_v: i32, + batch_stride_o: i32, + + swizzle_log: i32, + gemm_n_iterations_aligned: i32, + gemm_k_iterations_aligned: i32, + gemm_sv_m_block_iterations: i32, + + batch_ndim: i32, + alpha: f32, + softcapping: f32, + } + + let bk = q_shape.last().unwrap(); + + const BN: usize = 16; + const BM: usize = 16; + const WM: usize = 2; + const WN: usize = 2; + + let name = match (bk, itype) { + (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half", + (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half", + (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half", + (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half", + (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half", + (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float", + (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float", + (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float", + (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float", + (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float", + (other, SdpaDType::F16 | SdpaDType::F32) => { + return Err(MetalKernelError::SdpaHeadSizeMismatch { + variation: "full", + got: *other, + expected: vec![32, 64, 96, 128, 256], + }) + } + (_, SdpaDType::BF16) => { + return Err(MetalKernelError::SdpaHeadDTypeMismatch { + variation: "full", + got: SdpaDType::BF16, + }) + } + }; + + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + // q = (bs, qhead, seq, hidden) + // k/v = (bs, kv_head, seq, hidden) + + let qseq = q_shape[q_shape.len() - 2]; + + let m = q_shape[q_shape.len() - 2]; + let n = m; + let k = q_shape[q_shape.len() - 1]; + let bs_out = q_shape[0] * q_shape[1]; 
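The kernel name chosen above is entirely determined by the head dimension and the input dtype: the match only accepts head sizes 32/64/96/128/256 and rejects bf16. A hypothetical reconstruction of that rule, just to make the naming scheme explicit; it is not an API of this crate.

```rust
// Hypothetical helper mirroring the kernel-name match in call_sdpa_full.
// `itype` is the Metal scalar suffix used by the kernels ("half" or "float").
fn sdpa_full_kernel(head_dim: usize, itype: &str) -> Option<String> {
    let supported = [32, 64, 96, 128, 256];
    if !supported.contains(&head_dim) || (itype != "half" && itype != "float") {
        return None; // unsupported head size, or bf16: the real code returns an error
    }
    Some(format!(
        "steel_gemm_attention_bm_16_bn_16_bk_{head_dim}_itype_{itype}"
    ))
}

fn main() {
    assert_eq!(
        sdpa_full_kernel(128, "half").as_deref(),
        Some("steel_gemm_attention_bm_16_bn_16_bk_128_itype_half")
    );
    assert!(sdpa_full_kernel(80, "float").is_none());
}
```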
+
+    let batch_shape = [q_shape[0] * q_shape[1]];
+    let dk = q_shape[q_shape.len() - 1];
+    let ldq = dk;
+    let ldk = dk;
+    let ldv = dk;
+    let lds = BN;
+    let ldo = dk;
+
+    let tn = 1;
+    let tm = (m + BM - 1) / BM;
+
+    let b_stride_q = dk * qseq;
+    let b_stride_k = dk * qseq;
+    let b_stride_v = dk * qseq;
+    let b_stride_o = dk * qseq;
+    let swizzle_log = 0;
+    let gemm_n_iterations_aligned = (n + BN - 1) / BN;
+    let gemm_k_iterations_aligned = (k + bk - 1) / bk;
+    let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
+    let batch_ndim = batch_shape.len();
+
+    let alpha = if softcapping != 1. {
+        alpha / softcapping
+    } else {
+        alpha
+    };
+
+    let params = MLXFastAttentionParams {
+        m: m as i32,
+        n: n as i32,
+        k: k as i32,
+        ldq: ldq as i32,
+        ldk: ldk as i32,
+        ldv: ldv as i32,
+        lds: lds as i32,
+        ldo: ldo as i32,
+        tiles_n: tn,
+        tiles_m: tm as i32,
+        batch_stride_q: b_stride_q as i32,
+        batch_stride_k: b_stride_k as i32,
+        batch_stride_v: b_stride_v as i32,
+        batch_stride_o: b_stride_o as i32,
+        swizzle_log,
+        gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
+        gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
+        gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
+        batch_ndim: batch_ndim as i32,
+        alpha,
+        softcapping,
+    };
+    let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
+
+    impl EncoderParam for MLXFastAttentionParams {
+        fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+            encoder.set_bytes(
+                position,
+                core::mem::size_of::<MLXFastAttentionParams>() as u64,
+                &data as *const MLXFastAttentionParams as *const c_void,
+            );
+        }
+    }
+
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            params,
+            &batch_shape[..],
+            &batch_strides[..]
+ ) + ); + + let grid_dims = MTLSize { + width: 1, + height: tm as u64, + depth: bs_out as u64, + }; + let group_dims = MTLSize { + width: 32, + height: WM as u64, + depth: WN as u64, + }; + encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_dims, group_dims); + Ok(()) +} + +/// SDPA full is supported when: +/// - q head dim == 64, 96, 128 +/// - no mask +/// - q,k,v are contiguous +#[allow(clippy::too_many_arguments)] +pub fn call_sdpa_vector( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + q_offset: usize, + q_shape: &[usize], + q_buffer: &Buffer, + k_offset: usize, + k_shape: &[usize], + k_stride: &[usize], + k_buffer: &Buffer, + v_offset: usize, + v_stride: &[usize], + v_buffer: &Buffer, + output: &Buffer, + alpha: f32, + softcapping: f32, + itype: SdpaDType, +) -> Result<(), MetalKernelError> { + let bk = q_shape.last().unwrap(); + + let gqa_factor = (q_shape[1] / k_shape[1]) as i32; + let n = k_shape[2] as i32; + let b = (q_shape[0] * q_shape[1]) as i32; + let kstride = k_stride[1]; + let vstride = v_stride[1]; + + let name = match (bk, itype) { + (32, SdpaDType::F16) => "sdpa_vector_float16_t_32", + (64, SdpaDType::F16) => "sdpa_vector_float16_t_64", + (96, SdpaDType::F16) => "sdpa_vector_float16_t_96", + (128, SdpaDType::F16) => "sdpa_vector_float16_t_128", + (256, SdpaDType::F16) => "sdpa_vector_float16_t_256", + (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32", + (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64", + (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96", + (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128", + (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256", + (32, SdpaDType::F32) => "sdpa_vector_float_32", + (64, SdpaDType::F32) => "sdpa_vector_float_64", + (96, SdpaDType::F32) => "sdpa_vector_float_96", + (128, SdpaDType::F32) => "sdpa_vector_float_128", + (256, SdpaDType::F32) => "sdpa_vector_float_256", + (other, _) => { + return Err(MetalKernelError::SdpaHeadSizeMismatch { + variation: "vector", + got: *other, + expected: vec![32, 64, 96, 128, 256], + }) + } + }; + + let alpha = if softcapping != 1. 
{ + alpha / softcapping + } else { + alpha + }; + + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + // q = (bs, qhead, seq, hidden) + // k/v = (bs, kv_head, kv_seq, hidden) + + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + gqa_factor, + n, + kstride, + vstride, + alpha, + softcapping + ) + ); + + let grid_dims = MTLSize { + width: 1, + height: b as u64, + depth: 1 as u64, + }; + let group_dims = MTLSize { + width: 1024, + height: 1, + depth: 1, + }; + encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_dims, group_dims); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_im2col1d_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + (k_size, stride, padding, dilation): (usize, usize, usize, usize), + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1; + let dst_el = shape[0] * l_out * shape[1] * k_size; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output) + ); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_col2im1d( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + k_size: usize, + stride: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let l_in = shape[1]; + let c_out = shape[2]; + let l_out = (l_in - 1) * stride + k_size; + let dst_el = shape[0] * c_out * l_out; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (dst_el, l_out, l_in, c_out, k_size, stride, &input, output) + ); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_im2col_strided( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize), + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, 
Source::Conv, name)?; + + let h = shape[2]; + let w = shape[3]; + let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1; + let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1; + + let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k; + + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input, + output + ) + ); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_upsample_nearest_2d( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let dst_el = out_w * out_h * shape[0] * shape[1]; + let scale_w = shape[2] as f32 / out_w as f32; + let scale_h = shape[3] as f32 / out_h as f32; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (out_w, out_h, scale_w, scale_h, shape, strides, &input, output) + ); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_random_uniform( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + min: f32, + max: f32, + length: usize, + seed: &Buffer, + buffer: &Buffer, +) -> Result<(), MetalKernelError> { + if min >= max { + return Err(MetalKernelError::LoadLibraryError( + "min must be less than max".to_string(), + )); + } + let pipeline = kernels.load_pipeline(device, Source::Random, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, min, max, seed, buffer)); + + encoder.use_resource( + seed, + metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, + ); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_random_normal( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + mean: f32, + stddev: f32, + length: usize, + seed: &Buffer, + buffer: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Random, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + let odd = (length % 2 != 0) as usize; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); + + 
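Both random kernels dispatch one thread per pair of outputs, which is why the launch size is `length / 2 + odd` rather than `length`. A quick standalone check that this is just the ceiling division:

```rust
// `length / 2 + odd` is ceil(length / 2): each thread writes two values,
// with the last thread writing a single value when `length` is odd.
fn random_thread_count(length: usize) -> usize {
    let odd = (length % 2 != 0) as usize;
    length / 2 + odd
}

fn main() {
    for length in [1usize, 2, 7, 1024, 1025] {
        assert_eq!(random_thread_count(length), length.div_ceil(2));
    }
}
```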
encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, mean, stddev, seed, buffer)); + + encoder.use_resource( + seed, + metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, + ); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[derive(Debug, Clone, Copy)] +pub enum GgmlDType { + Q4_0, + Q4_1, + Q5_0, + Q5_1, + Q8_0, + Q8_1, + Q2K, + Q3K, + Q4K, + Q5K, + Q6K, + Q8K, + F16, + F32, + BF16, +} + +#[allow(clippy::too_many_arguments)] +pub fn call_quantized_matmul_mv_t( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: GgmlDType, + (b, m, n, k): (usize, usize, usize, usize), + lhs: &Buffer, + lhs_offset: usize, + rhs: &Buffer, + dst_offset: usize, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + // Everything is in reverse + let ne00 = k as i64; + let ne01 = n as i64; + let ne02 = b as i64; + let ne03 = 1i64; + + let nb00 = 0i64; + let nb01 = 0i64; + let nb02 = 0i64; + + let ne10 = k as i64; + let ne11 = m as i64; + let ne12 = b as i64; + let ne13 = 1i64; + + let nb10 = 0i64; + let nb11 = 0i64; + let nb12 = 0i64; + + let ne0 = n as i64; + let ne1 = m as i64; + let r2: u32 = (ne12 / ne02) as u32; + let r3: u32 = (ne13 / ne03) as u32; + + let (nth0, nth1, align) = match dtype { + GgmlDType::Q4_0 + | GgmlDType::Q4_1 + | GgmlDType::Q5_0 + | GgmlDType::Q5_1 + | GgmlDType::Q8_0 + | GgmlDType::Q8_1 => { + let nth0 = 8; + let nth1 = 8; + let align = 8; + (nth0, nth1, align) + } + GgmlDType::Q2K => { + // Fixing a bug in Metal for GGML + // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576 + let nth0 = 2; + let nth1 = 32; + let align = 4; + (nth0, nth1, align) + } + GgmlDType::Q4K => { + let nth0 = 4; + let nth1 = 8; + let align = 4; + (nth0, nth1, align) + } + GgmlDType::Q3K | GgmlDType::Q5K => { + let nth0 = 2; + let nth1 = 32; + let align = 4; + (nth0, nth1, align) + } + GgmlDType::Q6K => { + let nth0 = 2; + let nth1 = 32; + let align = 2; + (nth0, nth1, align) + } + GgmlDType::F16 | GgmlDType::BF16 | GgmlDType::Q8K => { + // Original implem uses rows + let nth0 = 32; + let nth1 = 1; + let align = 8; + (nth0, nth1, align) + } + GgmlDType::F32 => { + let nth0 = 32; + let nth1 = 1; + let align = 8; + (nth0, nth1, align) + } + }; + let thread_groups_count = MTLSize { + width: divide(ne01 as usize, align), + height: ne11 as u64, + depth: (ne12 * ne13) as u64, + }; + let threads_per_threadgroup = MTLSize { + width: nth0, + height: nth1, + depth: 1, + }; + let name = match dtype { + GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32", + GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32", + GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32", + GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32", + GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32", + GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32", + GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32", + GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32", + GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32", + GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32", + GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32", + GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", + GgmlDType::F16 => "kernel_mul_mv_f16_f32", + GgmlDType::BF16 => "kernel_mul_mv_bf16_f32", + GgmlDType::F32 => "kernel_mul_mv_f32_f32", + }; + + let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + 
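The `ne*`/`nb*` variables above re-express the `(b, m, n, k)` problem in ggml's reversed dimension convention before they are bound by the `set_params!` that follows: the quantized weight is `(k, n, b)`, the activations are `(k, m, b)`, and the output is `(n, m)`. A hypothetical helper spelling out that mapping, for reference only:

```rust
// ggml-style dimensions for the quantized mat-vec kernel, innermost first.
struct GgmlDims {
    ne00: i64, ne01: i64, ne02: i64, // weight:      k, n, b
    ne10: i64, ne11: i64, ne12: i64, // activations: k, m, b
    ne0: i64, ne1: i64,              // output:      n, m
}

fn ggml_dims(b: usize, m: usize, n: usize, k: usize) -> GgmlDims {
    GgmlDims {
        ne00: k as i64, ne01: n as i64, ne02: b as i64,
        ne10: k as i64, ne11: m as i64, ne12: b as i64,
        ne0: n as i64, ne1: m as i64,
    }
}

fn main() {
    // e.g. a 4096 x 11008 quantized weight applied to 4 rows of activations
    let d = ggml_dims(1, 4, 4096, 11008);
    assert_eq!(
        (d.ne00, d.ne01, d.ne02, d.ne10, d.ne11, d.ne12, d.ne0, d.ne1),
        (11008, 4096, 1, 11008, 4, 1, 4096, 4)
    );
}
```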
set_params!( + encoder, + ( + rhs, + (lhs, lhs_offset), + (dst, dst_offset), + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3 + ) + ); + encoder.use_resource(lhs, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + + encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); + Ok(()) +} + +/// - src0 is usually weight +/// - src1 is usually xs +#[allow(clippy::too_many_arguments)] +pub fn call_quantized_matmul_mm_t( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: GgmlDType, + src0_shape: &[usize], + src0_stride: &[usize], + src0: &Buffer, + src1_shape: &[usize], + src1_stride: &[usize], + src1: &Buffer, + src1_offset: usize, + dst_shape: &[usize], + dst_offset: usize, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + // Everything is in reverse + let ne00 = src0_shape[src0_shape.len() - 1] as i64; + let ne01 = src0_shape[src0_shape.len() - 2] as i64; + let ne02 = src0_shape[src0_shape.len() - 3] as i64; + let ne03 = src0_shape[src0_shape.len() - 4] as i64; + + let nb01 = src0_stride[src0_stride.len() - 2] as i64; + let nb02 = src0_stride[src0_stride.len() - 3] as i64; + let nb03 = src0_stride[src0_stride.len() - 4] as i64; + + let ne11 = src1_shape[src1_shape.len() - 2] as i64; + let ne12 = src1_shape[src1_shape.len() - 3] as i64; + let ne13 = src1_shape[src1_shape.len() - 4] as i64; + + let nb10 = src1_stride[src1_stride.len() - 1] as i64; + let nb11 = src1_stride[src1_stride.len() - 2] as i64; + let nb12 = src1_stride[src1_stride.len() - 3] as i64; + let nb13 = src1_stride[src1_stride.len() - 4] as i64; + + let ne0 = dst_shape[dst_shape.len() - 1] as i64; + let ne1 = dst_shape[dst_shape.len() - 2] as i64; + let r2 = (ne12 / ne02) as u32; + let r3 = (ne13 / ne03) as u32; + + let thread_groups_count = MTLSize { + width: divide(ne11 as usize, 32), + height: divide(ne01 as usize, 64), + depth: (ne12 * ne13) as u64, + }; + let threads_per_threadgroup = MTLSize { + width: 128, + height: 1, + depth: 1, + }; + let name = match dtype { + GgmlDType::Q4_0 => "kernel_mul_mm_q4_0_f32", + GgmlDType::Q4_1 => "kernel_mul_mm_q4_1_f32", + GgmlDType::Q5_0 => "kernel_mul_mm_q5_0_f32", + GgmlDType::Q5_1 => "kernel_mul_mm_q5_1_f32", + GgmlDType::Q8_0 => "kernel_mul_mm_q8_0_f32", + GgmlDType::Q8_1 => "kernel_mul_mm_q8_1_f32", + GgmlDType::Q2K => "kernel_mul_mm_q2_K_f32", + GgmlDType::Q3K => "kernel_mul_mm_q3_K_f32", + GgmlDType::Q4K => "kernel_mul_mm_q4_K_f32", + GgmlDType::Q5K => "kernel_mul_mm_q5_K_f32", + GgmlDType::Q6K => "kernel_mul_mm_q6_K_f32", + GgmlDType::Q8K => "kernel_mul_mm_q8_K_f32", + GgmlDType::F16 => "kernel_mul_mm_f16_f32", + GgmlDType::BF16 => "kernel_mul_mm_bf16_f32", + GgmlDType::F32 => "kernel_mul_mm_f32_f32", + }; + + let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + src0, + (src1, src1_offset), + (dst, dst_offset), + ne00, + ne02, + nb01, + nb02, + nb03, + ne12, + nb10, + nb11, + nb12, + nb13, + ne0, + ne1, + r2, + r3 + ) + ); + encoder.use_resource(src0, metal::MTLResourceUsage::Read); + encoder.use_resource(src1, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + + encoder.set_threadgroup_memory_length(0, 8192); + + 
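For the mat-mat path, each threadgroup of 128 threads covers a tile of 32 activation rows by 64 weight rows, which is where the `divide(ne11, 32)` / `divide(ne01, 64)` grid dispatched on the next line comes from. A standalone illustration of that arithmetic:

```rust
// Grid for kernel_mul_mm_*: ceil(ne11 / 32) x ceil(ne01 / 64) x (ne12 * ne13)
// threadgroups of 128 threads each. Illustration only.
fn mm_grid(ne01: usize, ne11: usize, ne12: usize, ne13: usize) -> (usize, usize, usize) {
    (ne11.div_ceil(32), ne01.div_ceil(64), ne12 * ne13)
}

fn main() {
    // a 4096-row quantized weight against 7 rows of activations
    assert_eq!(mm_grid(4096, 7, 1, 1), (1, 64, 1));
}
```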
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); + Ok(()) +} + +fn divide(m: usize, b: usize) -> NSUInteger { + ((m + b - 1) / b) as NSUInteger +} + +#[allow(clippy::too_many_arguments)] +pub fn call_pool2d( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + input: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = out_w * out_h * shape[0] * shape[1]; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (w_k, h_k, w_stride, h_stride, shape, strides, input, output) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_conv_transpose1d( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + dilation: usize, + stride: usize, + padding: usize, + out_padding: usize, + c_out: usize, + l_out: usize, + b_size: usize, + src_shape: &[usize], + src_strides: &[usize], + kernel_shape: &[usize], + kernel_strides: &[usize], + input: &Buffer, + input_offset: usize, + kernel: &Buffer, + kernel_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = c_out * l_out * b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + l_out, + stride, + padding, + out_padding, + dilation, + src_shape, + src_strides, + kernel_shape, + kernel_strides, + (input, input_offset), + (kernel, kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +pub struct CallConvTranspose2dCfg<'a> { + pub dilation: usize, + pub stride: usize, + pub padding: usize, + pub output_padding: usize, + pub c_out: usize, + pub out_w: usize, + pub out_h: usize, + pub b_size: usize, + pub input_dims: &'a [usize], + pub input_stride: &'a [usize], + pub kernel_dims: &'a [usize], + pub kernel_stride: &'a [usize], + pub input_offset: usize, + pub kernel_offset: usize, +} + +#[allow(clippy::too_many_arguments)] +pub fn call_conv_transpose2d( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + cfg: CallConvTranspose2dCfg, + input: &Buffer, + kernel: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size; + let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = ep.encoder(); + let encoder: 
&ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + cfg.out_w, + cfg.out_h, + cfg.stride, + cfg.padding, + cfg.output_padding, + cfg.dilation, + cfg.input_dims, + cfg.input_stride, + cfg.kernel_dims, + cfg.kernel_stride, + (input, cfg.input_offset), + (kernel, cfg.kernel_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(kernel, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_arg_sort( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + nrows: usize, + ncols: usize, + ncols_pad: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); + + let thread_group_count = MTLSize { + width: 1, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: ncols_pad as u64, + height: 1, + depth: 1, + }; + + encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub enum GemmDType { + BF16, + F16, + F32, +} + +#[allow(clippy::too_many_arguments)] +pub fn call_mlx_gemm( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + dtype: GemmDType, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + #[derive(Debug)] + #[repr(C)] + struct GemmParams { + m: i32, + n: i32, + k: i32, + lda: i32, + ldb: i32, + ldd: i32, + tiles_n: i32, + tiles_m: i32, + batch_stride_a: isize, + batch_stride_b: isize, + batch_stride_d: isize, + swizzle_log: i32, + gemm_k_iterations_aligned: i32, + batch_ndim: i32, + } + assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // lhs has shape b, m, k + // We also allow for the case where the stride on the minor dimension is not as expected but + // there is a single element. 
+ let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { + (k as i32, false) + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { + (m as i32, true) + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + // rhs has shape b, k, n + let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { + (n as i32, false) + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { + (k as i32, true) + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2); + // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422 + let constants = Some(ConstantValues::new(vec![ + (10, Value::Bool(/* has_batch */ b > 1)), + (100, Value::Bool(/* use_out_source */ false)), + (110, Value::Bool(/* do_axpby */ false)), + (200, Value::Bool(/* align_m */ m % bm == 0)), + (201, Value::Bool(/* align_n */ n % bn == 0)), + (202, Value::Bool(/* align_k */ k % bk == 0)), + (300, Value::Bool(/* do_gather */ false)), + ])); + + let swizzle_log = 0; + let tile = 1 << swizzle_log; + let tn = n.div_ceil(bn); + let tm = m.div_ceil(bm); + let tn = tn * tile; + let tm = tm.div_ceil(tile); + + let batch_stride_a = if lhs_stride.len() > 2 { + lhs_stride[lhs_stride.len() - 3] + } else { + m * k + }; + let batch_stride_b = if rhs_stride.len() > 2 { + rhs_stride[rhs_stride.len() - 3] + } else { + n * k + }; + + let gemm_params = GemmParams { + m: m as i32, + n: n as i32, + k: k as i32, + lda, + ldb, + ldd: n as i32, + tiles_n: tn as i32, + tiles_m: tm as i32, + swizzle_log, + batch_stride_a: batch_stride_a as isize, + batch_stride_b: batch_stride_b as isize, + batch_stride_d: (m * n) as isize, + batch_ndim: 1i32, + gemm_k_iterations_aligned: (k / bk) as i32, + }; + let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b]; + + // TODO(laurent): generate the name + // template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] + let name = match (dtype, a_trans, b_trans) { + (GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2", + (GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2", + (GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2", + (GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2", + (GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2", + (GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2", + (GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2", + (GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2", + (GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2", + (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2", + (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2", + (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2", + }; + let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(3, Some(output), 0); + 
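The `tn`/`tm` computed above tile the output into `bm x bn = 32 x 32` blocks, one per threadgroup of `32 * wn * wm = 128` threads; with `swizzle_log = 0` the tile order is left unchanged. A standalone recomputation of the grid that `grid_size` below is built from, mirroring the arithmetic above as a sketch:

```rust
// Threadgroup grid for the MLX GEMM dispatch: one 32 x 32 output tile per group.
fn mlx_gemm_grid(m: usize, n: usize, b: usize) -> (usize, usize, usize) {
    let (bm, bn) = (32usize, 32usize);
    let swizzle_log = 0u32;
    let tile = 1usize << swizzle_log;
    let tn = n.div_ceil(bn) * tile;
    let tm = m.div_ceil(bm).div_ceil(tile);
    (tn, tm, b) // (width, height, depth) of the dispatch
}

fn main() {
    // a 100 x 70 output: 3 tiles along n, 4 tiles along m, single batch
    assert_eq!(mlx_gemm_grid(100, 70, 1), (3, 4, 1));
}
```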
encoder.set_bytes(
+        4,
+        std::mem::size_of::<GemmParams>() as u64,
+        &gemm_params as *const GemmParams as *const c_void,
+    );
+    encoder.set_bytes(
+        6, // batch_shape
+        std::mem::size_of::<i32>() as u64,
+        &(b as i32) as *const i32 as *const c_void,
+    );
+    encoder.set_bytes(
+        7,
+        (std::mem::size_of::<isize>() * batch_strides.len()) as u64,
+        batch_strides.as_ptr() as *const c_void,
+    );
+
+    let grid_size = MTLSize {
+        width: tn as u64,
+        height: tm as u64,
+        depth: /* batch_size_out */ b as u64,
+    };
+    let group_size = MTLSize {
+        width: 32,
+        height: wn,
+        depth: wm,
+    };
+    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(grid_size, group_size);
+    Ok(())
+}
+
+pub fn call_const_fill(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    name: &'static str,
+    length: usize,
+    output: &Buffer,
+    v: f32,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
+    let encoder = ep.encoder();
+    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(encoder, (output, v, length));
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/mcandle/mcandle-metal-kernels/src/libMetalFlashAttention.metallib b/mcandle/mcandle-metal-kernels/src/libMetalFlashAttention.metallib
new file mode 100644
index 000000000..1e2d1acf3
Binary files /dev/null and b/mcandle/mcandle-metal-kernels/src/libMetalFlashAttention.metallib differ
diff --git a/mcandle/mcandle-metal-kernels/src/mlx_gemm.metal b/mcandle/mcandle-metal-kernels/src/mlx_gemm.metal
new file mode 100644
index 000000000..1b5cad92f
--- /dev/null
+++ b/mcandle/mcandle-metal-kernels/src/mlx_gemm.metal
@@ -0,0 +1,1440 @@
+// MLX Kernel extracted from:
+// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/steel/gemm
+// Copyright © 2024 Apple Inc.
+ +#include +#include +#include + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/params.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const size_t batch_stride_a; + const size_t batch_stride_b; + const size_t batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const int batch_ndim; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int ldc; + const int fdc; + + const size_t batch_stride_c; + + const float alpha; + const float beta; +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/loader.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/transforms.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/mma.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 
8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) const {
+    // Adjust for simdgroup and thread location
+    D += (sm + tm) * ldd + tn + sn;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread const auto& accum = results[i * TN + j].thread_elements();
+        int offset = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
+
+        // Write out D
+        D[offset] = outs[0];
+        D[offset + 1] = outs[1];
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
+    // Adjust for simdgroup and thread location
+    D += (sm + tm) * ldd + (tn + sn);
+    dst_tile_dims -= short2(tn + sn, sm + tm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in C
+          thread const auto& accum = results[i * TN + j].thread_elements();
+          int offset = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue and output C
+          if (j * TN_stride < dst_tile_dims.x) {
+            D[offset] = Epilogue::apply(accum[0]);
+          }
+
+          if (j * TN_stride + 1 < dst_tile_dims.x) {
+            D[offset + 1] = Epilogue::apply(accum[1]);
+          }
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = results[i * TN + j].thread_elements();
+
+        // Apply epilogue
+        accum[0] = epilogue_op.apply(accum[0]);
+        accum[1] = epilogue_op.apply(accum[1]);
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm + tm) * ldc + (tn + sn) * fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in C
+        thread auto& accum = results[i * TN + j].thread_elements();
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        // Apply epilogue
+        accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
+        accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm + tm) * ldc + (tn + sn) * fdc;
+    dst_tile_dims -= short2(tn + sn, sm + tm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+
STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Read C + U c_elems[2] = {0}; + + if ((j * TN_stride + 1) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + c_elems[1] = C[offset_c + fdc]; + } else if ((j * TN_stride) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + } + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], c_elems[0]); + accum[1] = epilogue_op.apply(accum[1], c_elems[1]); + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/gemm.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? 
BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +// utils.h +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template +METAL_FUNC stride_t elem_to_loc( + uint elem, + device const int* shape, + device const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template +METAL_FUNC stride_t elem_to_loc( + uint elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template +METAL_FUNC stride_t elem_to_loc( + stride_t elem, + device const int* shape, + device const stride_t* strides, + int ndim) { + stride_t 
loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template +METAL_FUNC stride_t elem_to_loc( + stride_t elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +// Non templated version to handle arbitrary dims +template +METAL_FUNC stride_t elem_to_loc( + uint3 elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} + + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +constant bool do_gather [[function_constant(300)]]; + +constant bool gather_bias = do_gather && use_out_source; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], + const constant int* operand_shape [[buffer(13), 
function_constant(do_gather)]], + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + + // Handle gather + if (do_gather) { + // Read indices + uint32_t indx_A, indx_B, indx_C; + + if (has_batch) { + const constant size_t* indx_A_bstrides = batch_strides; + const constant size_t* indx_B_bstrides = + batch_strides + params->batch_ndim; + + ulong2 indx_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + indx_A_bstrides, + indx_B_bstrides, + params->batch_ndim); + indx_A = lhs_indices[indx_offsets.x]; + indx_B = rhs_indices[indx_offsets.y]; + + if (use_out_source) { + const constant size_t* indx_C_bstrides = + indx_B_bstrides + params->batch_ndim; + auto indx_offset_C = elem_to_loc( + tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); + indx_C = C_indices[indx_offset_C]; + } + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + + if (use_out_source) { + indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; + } + } + + // Translate indices to offsets + int batch_ndim_A = operand_batch_ndim.x; + const constant int* batch_shape_A = operand_shape; + const constant size_t* batch_strides_A = operand_strides; + A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); + + int batch_ndim_B = operand_batch_ndim.y; + const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; + const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; + B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + + if (use_out_source) { + int batch_ndim_C = operand_batch_ndim.z; + const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; + const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; + C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + } + + } + + // Handle regular batch + else { + if (has_batch) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + } + + D += 
params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + 
LoopAlignment<false, false, true>{});
+
+      // Do epilogue
+      if (use_out_source) {
+        if (do_axpby) {
+          mma_op.apply_epilogue_safe(
+              C,
+              addmm_params->ldc,
+              addmm_params->fdc,
+              short2(tgp_bn, tgp_bm),
+              epilogue_op_axpby);
+        } else {
+          mma_op.apply_epilogue_safe(
+              C,
+              addmm_params->ldc,
+              addmm_params->fdc,
+              short2(tgp_bn, tgp_bm),
+              epilogue_op_add);
+        }
+      }
+
+      // Store results to device memory
+      return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
+    }
+  }
+}
+
+#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  template [[host_name("gemm_" #tname "_" #iname "_" #oname "_" #bm "_" #bn "_" #bk "_" #wm "_" #wn)]] \
+  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
+      const device itype *A [[buffer(0)]], \
+      const device itype *B [[buffer(1)]], \
+      const device itype *C [[buffer(2), function_constant(use_out_source)]], \
+      device itype *D [[buffer(3)]], \
+      const constant GEMMParams* params [[buffer(4)]], \
+      const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
+      const constant int* batch_shape [[buffer(6)]], \
+      const constant size_t* batch_strides [[buffer(7)]], \
+      const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
+      const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
+      const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
+      const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
+      const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
+      const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint3 lid [[thread_position_in_threadgroup]]);
+
+#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
+  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
+
+instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2)
+instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2)
+#if defined(__HAVE_BFLOAT__)
+instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 32, 16, 2, 2)
+#endif
diff --git a/mcandle/mcandle-metal-kernels/src/quantized.metal b/mcandle/mcandle-metal-kernels/src/quantized.metal
new file mode 100644
index 000000000..1feeb0e80
--- /dev/null
+++ b/mcandle/mcandle-metal-kernels/src/quantized.metal
@@ -0,0 +1,7741 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? 
(x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +#if defined(__HAVE_BFLOAT__) +typedef matrix bfloat4x4; +#endif + +// QK = number of values after dequantization +// QK_K = super-block size + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +#define QK4_0 32 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(half) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + union { + struct { + half d; // delta + half m; // min + }; + half2 dm; + }; + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(half) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + union { + struct { + half d; // delta + half m; // min + }; + half2 dm; + }; + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(half) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + union { + struct { + half d; // delta + half s; // d * sum(qs[i]) + }; + half2 ds; + }; + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(half) + QK8_1, "wrong q8_1 block size/padding"); + +typedef struct { + half d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(half) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +typedef struct { + half d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(half) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +typedef struct { + half d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(half) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +typedef struct { + half d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + +// +// Ternary quantization +// + +// 1.6875 bpw +typedef struct { + uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256) + uint8_t qh[QK_K/64]; // 4 elements per byte + half d; +} block_tq1_0; +static_assert(sizeof(block_tq1_0) == sizeof(half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding"); + +// 2.0625 bpw +typedef struct { + uint8_t qs[QK_K/4]; // 2 bits per element + half d; +} block_tq2_0; +static_assert(sizeof(block_tq2_0) == sizeof(half) + QK_K / 4, "wrong tq2_0 block size/padding"); + +// +// Super-block quantization structures +// + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elements each 
+// Effectively 2.625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + union { + struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + }; + half2 dm; + }; +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 3.4375 bits per weight +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + half d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +typedef struct { + union { + struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + }; + half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); + +// 5-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +typedef struct { + union { + struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + }; + half2 dm; + }; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + +// (Almost) "true" 2-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 2.0625 bpw because of the 16-bit scale for each block of 256. 
+typedef struct { + half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + +// 2.3125 bpw quants +typedef struct { + half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + +// 2.5625 bpw quants +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; +static_assert(sizeof(block_iq2_s) == sizeof(half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); + +// (Almost) "true" 3-bit quantization. +// Due to the need to use blocks as per ggml design, it ends up using +// 3.0625 bpw because of the 16-bit scale for each block of 256. +typedef struct { + half d; + uint8_t qs[3*QK_K/8]; +} block_iq3_xxs; +static_assert(sizeof(block_iq3_xxs) == sizeof(half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); + +// 3.4375 bpw +#define IQ3S_N_SCALE QK_K/64 +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[IQ3S_N_SCALE]; +} block_iq3_s; +static_assert(sizeof(block_iq3_s) == sizeof(half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); + +// 1.5625 bpw +typedef struct { + half d; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; +} block_iq1_s; +static_assert(sizeof(block_iq1_s) == sizeof(half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); + +// 1.75 bpw +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; +static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); + +// Used by IQ1_M quants +typedef union { + half f16; + uint16_t u16; +} iq1m_scale_t; + +// Non-linear quants +#define QK4_NL 32 +typedef struct { + half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; +static_assert(sizeof(block_iq4_nl) == sizeof(half) + QK4_NL/2, "wrong iq4_nl block size/padding"); + +typedef struct { + half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; +static_assert(sizeof(block_iq4_xs) == sizeof(half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); + +#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = { +#define GGML_TABLE_END() }; + +GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8) + 1, 2, 4, 8, 16, 32, 64, 128 +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +GGML_TABLE_END() + +//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) + 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, + 
0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, + 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff, + 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff, + 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff, + 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, + 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff, + 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff, + 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff, + 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff, + 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, + 0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, + 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff, + 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff, + 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff, + 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff, + 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff, + 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff, + 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff, + 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff, + 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, + 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff, + 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff, + 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff, + 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff, + 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff, + 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, + 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff, + 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff, + 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff, + 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, + 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, +GGML_TABLE_END() +//#endif + + +GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 
0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 
0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 
0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 
0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 
0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024) + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 
0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 
0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 
0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 
0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256) + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 
0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +GGML_TABLE_END() + +GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 
0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +GGML_TABLE_END() + +#define NGRID_IQ1S 2048 +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S) + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 
0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 
0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 
0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 
0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 
0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 
0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 
0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +GGML_TABLE_END() + + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast 
across all dims +// cons: not very efficient +kernel void kernel_add( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_sub( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int64_t & offs, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_mul( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, 
+ constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_div( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + } +} + +template +kernel void kernel_repeat( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 % ne03; + const int64_t i02 = i2 % ne02; + const int64_t i01 = i1 % ne01; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device 
char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i00 = i0 % ne00; + *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_sub_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + +kernel void kernel_mul_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant uint64_t & nb [[buffer(28)]], + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! 
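+    // (most likely because the fast-math tanh gets expanded via exp(), which can
+    //  overflow to inf for large |x|, whereas precise::tanh saturates to +/-1)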
+ // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +template +kernel void kernel_soft_max( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 
= (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +template +kernel void kernel_soft_max_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? 
slope*pmask[i00] : 0.0f))); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +// TODO: optimize +kernel void kernel_ssm_conv_f32( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t 
& nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = ne10; + const int64_t ncs = ne00; + const int64_t nr = ne01; + const int64_t n_t = ne1; + const int64_t n_s = ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); + device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// TODO: optimize +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); + device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 20.0f ? 
log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + // MEAN + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float mean = sum[0] / ne00; + + // recenter and VARIANCE + threadgroup_barrier(mem_flags::mem_threadgroup); + device float * y = dst + tgpig*ne00; + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; + sum[tpitg] += y[i00] * y[i00]; + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float variance = sum[0] / ne00; + + const float scale = 1.0f/sqrt(variance + eps); + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = all_sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + all_sum = buf[tiisg]; + all_sum = simd_sum(all_sum); + } + + const float mean = all_sum/ne00; + const float scale = 1.0f/sqrt(mean + eps); + + device float4 * y = (device float4 *) (dst + tgpig*ne00); + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int32_t & n_groups, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint 
ntg[[threads_per_threadgroup]]) { + const int64_t ne = ne00*ne01*ne02; + const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device 
const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// guard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template +void mul_vec_q_n_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, uint tiisg, uint sgitg) { + const int nb = ne00/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; + + device const float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. 
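+    // note: the cached values are pre-scaled by 1, 1/256, 1/16 and 1/4096 so that
+    // block_q_n_dot_y can multiply the raw masked nibbles (bit offsets 0, 8, 4 and 12
+    // of each packed 16-bit word) without shifting them down first, e.g. a nibble
+    // stored in bits 8..11 enters the sum as (q << 8) * (y/256) = q*y; sumy is the
+    // plain sum of the same values and folds in the q4_0/q5_0 offset (-8/-16) or the
+    // q4_1/q5_1 min m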
+ for (int ib = ix; ib < nb; ib += nw/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + } + + yb += QK4_0 * 16; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + 
uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + + +#define NB_Q8_0 8 + +void kernel_mul_mv_q8_0_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = ne00/QK8_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[NB_Q8_0]; + float sumf[nr]={0.f}; + + const int ix = tiisg/4; + const int il = tiisg%4; + + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + float sumq = 0.f; + for (int iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*x[ib+row*nb].d; + } + + yb += NB_Q8_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); +} + +#define N_MV_T_T 4 + +template +void kernel_mul_mv_impl( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_MV_T_T; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const T0 * x = (device const T0 *) (src0 + offset0); + + if (ne00 < 128) { + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + 
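+                // (r1 indexes the N_MV_T_T = 4 consecutive src1 rows handled per
+                //  threadgroup; rows at or past ne11 do not exist, hence the early exit)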
} + + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (T0) x[i] * (T1) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + device const T14 * y4 = (device const T14 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +template +kernel void kernel_mul_mv( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_impl( + src0, + src1, + dst, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +typedef decltype(kernel_mul_mv) mul_mv_t; + +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; + +template +kernel void kernel_mul_mv_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const T * x = (device const T *) (src0 + offset0); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const T4 * x4 = (device const T4 *) x; + device const float4 * y4 = (device const float4 *) y; + + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + } + + float all_sum = simd_sum(sumf); + + if (tiisg == 0) { + for (int i = 
4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; + +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; + +// Assumes row size (ne00) is a multiple of 4 +template +kernel void kernel_mul_mv_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const T4 * x4 = (device const T4 *) (src0 + offset0); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; + +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
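+// rope_yarn() below blends the interpolated angle (freq_scale * theta) with the
+// extrapolated one (theta) using a ramp over corr_dims[0..1], and scales sin/cos
+// by mscale to compensate for the interpolation; rope_yarn_corr_dims() derives the
+// ramp boundaries from beta_fast/beta_slow via the inverse formula noted further down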
+static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); +} + +template +kernel void kernel_rope_norm( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & n_ctx_orig, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/n_dims; + + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_neox( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & n_ctx_orig, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/n_dims; + + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; + +typedef void (im2col_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; + + const int32_t offset_dst = + (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 
tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + + const int32_t d = tgpig[0] / CHW; + const int32_t chw = tgpig[0] % CHW; + const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int32_t HW = tgpig[0] % KHW; + + const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= N) { + return; + } + + const int32_t tpitg_1 = HW / KW; + const int32_t tpitg_2 = HW % KW; + + const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + + const int32_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & sf0, + constant float & sf1, + constant float & sf2, + constant float & sf3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3/sf3; + const int64_t i02 = i2/sf2; + const int64_t i01 = i1/sf1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int64_t i00 = i0/sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_ptr[0] = src0_ptr[0]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + 
constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols_pad) return; + + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? 
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant float & slope, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; +} + +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + //const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = 
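// Shared-memory bookkeeping for the flash-attention kernel: per query row the
// threadgroup keeps D halves of Q data plus, for each of the nsg simdgroups,
// SH = C + Q scratch slots (the C scores of the current KV block plus Q slots for
// the diagonal rescale factors). The scratch is addressed as float, i.e. two halves
// per slot, hence T = D + 2*nsg*SH halves per query, viewed as TF = T/2 floats or
// T4 = T/4 half4. Purely as an illustration (the real D/Q/C/nsg come from the host
// dispatch): D = 128, Q = 8, C = 32, nsg = 4 gives SH = 40 and T = 128 + 2*4*40 = 448.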
T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? 
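// ALiBi slope, computed here as slope = base^exp with base = m0, exp = h + 1 for the
// first n_head_log2 heads and base = m1, exp = 2*(h - n_head_log2) + 1 for the rest.
// Assuming the usual ggml host-side values m0 = 2^(-max_bias/n_head_log2) and
// m1 = 2^(-max_bias/(2*n_head_log2)), this reproduces the geometric slope sequence of
// the ALiBi paper, e.g. 8 heads with max_bias = 8 give slopes 1/2, 1/4, ..., 1/256.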
h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const float m = M[j]; + + // scale and apply the logitcap / mask + float s = ss[j*TF + tiisg]*scale; + + if (logit_softcap != 0.0f) { + s = logit_softcap*precise::tanh(s); + } + + if (mask != q) { + // mqk = mqk + mask*slope + s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; + } + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + tiisg] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + 
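// The bookkeeping above is the streaming ("online") softmax of flash attention: for
// each new block of scores s,
//   M' = max(M, s)
//   S' = S*exp(M - M') + sum(exp(s - M'))
//   O' = diag(exp(M - M'))*O + P*V,   with P = exp(s - M')
// and the same rule merges two partials (S0, M0, O0) and (S1, M1, O1): with
// M = max(M0, M1), ms0 = exp(M0 - M), ms1 = exp(M1 - M) we get S = S0*ms0 + S1*ms1
// and O = diag(ms0)*O0 + diag(ms1)*O1, which is exactly what the ms0/ms1 matrices
// loaded just below apply. The final output is O/S.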
simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb23, + constant uint64_t & nb31, + constant int64_t & ne1, + constant int64_t & ne2, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + constant float & logit_softcap, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const uint32_t h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? 
h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + float4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = (float4) sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + float4x4 mk; + mk[0] = (float4) pk4[i + 0*(nb11/8)]; + mk[1] = (float4) pk4[i + 1*(nb11/8)]; + mk[2] = (float4) pk4[i + 2*(nb11/8)]; + mk[3] = (float4) pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tiisg == 0) { + mqk *= scale; + + if (logit_softcap != 0.0f) { + mqk = logit_softcap*precise::tanh(mqk); + } + + mqk += (mask != q) ? 
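// The five simd_shuffle_down steps above (offsets 16, 8, 4, 2, 1) form a binary
// reduction tree across the 32 lanes of the simdgroup (N_SIMDWIDTH): each step adds
// the partial dot product held by the lane that many positions higher, so lane 0 ends
// up with the full Q*K^T scores for the 4 columns handled in this iteration. Only
// lane 0 (tiisg == 0) then applies the scale, the optional logit softcap and the mask,
// and writes the result into ss4.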
((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + +template +kernel void kernel_cpy( + device const void * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - 
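// kernel_cpy (here) linearizes the source row position (i03, i02, i01) into a flat
// element index n and re-decomposes it against the destination shape (ne0..ne3), so
// the same kernel handles copies (and dtype conversions via the (T1) cast) between
// tensors that have the same number of elements but different layouts. The threads of
// a threadgroup then stride over the ne00 elements of one source row.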
i1*ne0); + + device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = (T1) src[0]; + } +} + +typedef decltype(kernel_cpy) kernel_cpy_t; + +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; + +kernel void kernel_cpy_f32_q8_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; + + device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
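// Q8_0 quantization above: each block of QK8_0 (32) floats stores one scale
// d = amax/127 and 32 int8 codes q = round(x/d); dequantization is x ~= d*q, so the
// worst-case rounding error is d/2. For example, amax = 6.35 gives d = 0.05 and
// x = 1.23 is stored as q = 25, reconstructed as 1.25.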
1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + +kernel void kernel_cpy_f32_q4_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q4_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + 
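// Q4_0 above: the scale comes from the value with the largest magnitude, d = max/-8
// (sign included), and each float becomes a 4-bit code in [0, 15] via x/d + 8.5
// clamped to 15; dequantization is x ~= d*(q - 8). Two codes share a byte: the low
// nibble holds element j, the high nibble element j + QK4_0/2.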
const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q5_0( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 
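// Q5_0 extends the same scheme to 5-bit codes in [0, 31] with d = max/-16 and
// dequantization x ~= d*(q - 16). The four low bits of each code go into the qs
// nibbles as in Q4_0, while the fifth bit is collected below into the 32-bit mask qh:
// bit j for element j of the first half of the block, bit j + QK5_0/2 for the second.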
1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? 
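// best_index_int8 binary-searches a monotonically increasing codebook and returns the
// index of the closer neighbour; e.g. against the 16-entry kvalues_iq4nl_f table below,
// x = 30 falls between 25 (index 10) and 38 (index 11) and maps to 10. The IQ4_NL copy
// kernel that follows quantizes against this non-uniform codebook and then refits the
// block scale by weighted least squares: minimizing sum w*(x - d*v)^2 over d gives
// d = sum(w*x*v) / sum(w*v*v), with weights w = x^2 so large-magnitude entries dominate.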
mu-1 : mu; +} + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_cpy_f32_iq4_nl( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + + } +} + +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & dim, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? 
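// o[dim], selected in the ternary here, is the extent of src0 along the concatenation
// axis: destination elements whose index on that axis still falls inside src0 are read
// from src0, everything past it is read from src1 with the index shifted back by o[dim].
// E.g. concatenating along dim = 2 with ne02 = 4: output i2 = 0..3 reads src0, while
// i2 >= 4 reads src1 at i2 - 4.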
ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + device const float * x; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} + +void kernel_mul_mv_q2_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q2_K) * nb; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant 
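// The *_K matrix-vector kernels in this section all follow the same pattern: weights
// come in super-blocks of QK_K (256) values with one super-block scale d (and, for the
// formats that have one, a super-block min dmin) plus small packed per-sub-block scales.
// Each simdgroup walks nb = ne00/QK_K super-blocks for one or more output rows,
// accumulating d * sub_scale * q contributions (minus dmin * sub_min * sum(y) terms
// where applicable) into per-lane partial sums, and simd_sum() folds the 32 lanes
// before lane 0 writes the row result to dst.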
int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int ip = tid/4; // 0 or 1 + const int il = 2*((tid%4)/2); // 0 or 2 + const int ir = tid%2; + const int n = 8; + const int l0 = n*ir; + + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; + + const int shift = 2*il; + const float v1 = il == 0 ? 
4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + + const int step = sizeof(block_q3_K) * nb / 2; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[2] = {0.f}; + float sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 4) { + + for (int l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 
0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += step; + h += step; + a += step; + dh += step; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + if (tiisg == 0) { + for (int row = 0; row < 2; ++row) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + } + } +} + +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 
acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q5_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf[2]={0.f}; + + const int step = sizeof(block_q5_K) * nb; + + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int iq = tid/4; + const int ir = tid%4; + const int n = 8; + + const int l0 = n*ir; + const int q_offset = 32*iq + l0; + const int y_offset = 64*iq + l0; + + const uint8_t hm1 = 1u << (2*iq); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * 
dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q6_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int im = tgpig.z; + + const int row = 2 * r0 + sgitg; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; + const int n = 4; + const int 
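// Q6_K: each weight is a 6-bit code split across two arrays, 4 low bits in ql and
// 2 high bits in qh. The inner loop below reassembles it, e.g.
//   (q1[l] & 0xF) | ((qh[l] & kmask1) << 4)
// gives a code in [0, 63], and subtracting 32 recenters it, so the dequantized value
// is d * sc * (q - 32) with the per-16-element int8 sub-block scale sc.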
l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += 2) { + + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = x[i].qh + q_offset_h; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + y_offset; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} + +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +// ======================= "True" 2-bit + +void kernel_mul_mv_iq2_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const 
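// The IQ2_* kernels first stage their lookup tables into threadgroup memory (as done
// just above): the 256-entry iq2xxs_grid of packed 8-weight patterns and the 128-entry
// ksigns_iq2xs sign table are copied cooperatively by the threadgroup's threads and a
// barrier is issued, so the per-weight table lookups in the inner loop hit fast shared
// memory instead of the global constant tables.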
int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += nb*sizeof(block_iq2_xxs)/2; + q2 += nb*sizeof(block_iq2_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / 
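// Layout difference between the two "true 2-bit" variants: in IQ2_XXS (above) the four
// 8-bit bytes of q2[0..1] index the 256-entry grid, while aux32 = q2[2] | (q2[3] << 16)
// packs four 7-bit sign indices plus a 4-bit scale, giving d = db*(0.5 + (aux32 >> 28)).
// In IQ2_XS (below) each 16-bit q2[l] carries a 9-bit grid index (q2[l] & 511) and a
// 7-bit sign index (q2[l] >> 9), with explicit 4-bit sub-block scales ls so that
// d = db*(0.5 + ls). Both kernels apply a final factor of 0.25 to the accumulated sums.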
(QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); + const uint8_t signs = shared_signs[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_xxs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos 
= (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); + const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_xxs)/2; + q3 += nb*sizeof(block_iq3_xxs); + gas += nb*sizeof(block_iq3_xxs)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + } + } +} + +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq3_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * values = (threadgroup uint32_t 
*)shared_values; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += nb*sizeof(block_iq3_s)/2; + qs += nb*sizeof(block_iq3_s); + qh += nb*sizeof(block_iq3_s); + sc += nb*sizeof(block_iq3_s); + signs += nb*sizeof(block_iq3_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq2_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + 
offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += nb*sizeof(block_iq2_s)/2; + qs += nb*sizeof(block_iq2_s); + qh += nb*sizeof(block_iq2_s); + sc += nb*sizeof(block_iq2_s); + signs += nb*sizeof(block_iq2_s); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_iq1_s_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + 
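    // A minimal scalar sketch of how one 32-weight iq1_s sub-block dequantizes, assuming the
    // same layout this kernel reads (qs: four low-byte grid indices per sub-block; qh: three
    // high bits per index plus a 3-bit scale and a sign bit). The names w[], idx[] and the
    // stand-alone qh value are illustrative only; the kernel below never materializes w[] and
    // instead folds the constant (-1 +/- IQ1S_DELTA) term through sumy, the sum of the 32
    // activations, which is an equivalent rearrangement of the same dot product.
    //
    //   const float dl    = d * (2*((qh >> 12) & 7) + 1);                 // 3-bit block scale
    //   const float delta = (qh & 0x8000) ? -1 - IQ1S_DELTA
    //                                     : -1 + IQ1S_DELTA;              // sign-selected shift
    //   const int idx[4] = {
    //       qs[0] | ((qh << 8) & 0x700), qs[1] | ((qh << 5) & 0x700),
    //       qs[2] | ((qh << 2) & 0x700), qs[3] | ((qh >> 1) & 0x700),
    //   };
    //   for (int g = 0; g < 4; ++g) {                                     // 4 groups of 8 weights
    //       const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + idx[g]);
    //       for (int j = 0; j < 4; ++j) {
    //           w[8*g + j    ] = dl * ((grid[j] & 0xf) + delta);          // low nibble
    //           w[8*g + j + 4] = dl * ((grid[j] >> 4)  + delta);          // high nibble
    //       }
    //   }
    //
    // With this expansion, dot(w, y) == dl * (sum_j y[j]*grid_nibble[j] + delta * sumy),
    // which is exactly the form accumulated into sumf[row] in the loop below.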
const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float sumy = 0; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += nb*sizeof(block_iq1_s)/2; + qs += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq1_m_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_value, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + iq1m_scale_t scale; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? 
-1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); + + sc += nb*sizeof(block_iq1_m)/2; + qs += nb*sizeof(block_iq1_m); + qh += nb*sizeof(block_iq1_m); + } + + y4 += 32 * 32; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +void kernel_mul_mv_iq4_xs_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values_i8, + uint3 tgpig, + uint tiisg, + uint sgitg) { + + threadgroup float * 
shared_values = (threadgroup float *)shared_values_i8; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant 
uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} + +//============================= templates and their specializations ============================= + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +#if defined(__HAVE_BFLOAT__) +template +void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} +#endif + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 
0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 
12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.f; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; + for (int i = 0; i < 16; ++i) { + const half q = il&1 ? 
((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } +} + +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? 
-1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; + + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; + + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. 
il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +kernel void kernel_get_rows_q( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + float4x4 temp; + dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +template +kernel void kernel_get_rows_f( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + } +} + +kernel void kernel_get_rows_i32( + device const void * src0, + device const void * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + } +} + + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 
4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +kernel void kernel_mul_mm(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup T * sa = (threadgroup T *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_T8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 mc[8]; + + for (short i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb13 * i13 + + nb12 * i12 + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + T4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? 
x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE; + + #pragma unroll(8) + for (short i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); + } + } + } + } +} + +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +template +void kernel_mul_mm_id_impl( + device const uchar * src0, + device const uchar * src1, + threadgroup ushort2 * rowids, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + int64_t ne1, + int64_t ne0ne1, + threadgroup uchar * shared_memory, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + + if (r1 * BLOCK_SIZE_N >= ne1) return; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + short il = (tiitg % THREAD_PER_ROW); + + ushort offset1 = il/nl; + + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * id[1] + + nb11 * (id[0] % ne11) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + { + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0); + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int joff = jid[0] * ne0 + jid[1] * ne0ne1; + for (int i = 0; i < n_rows; i++) { + *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +template +kernel void kernel_mul_mm_id( + device const uchar * src0s, + device const uchar * src1, + device float * dst, + device const uchar * ids, + constant int64_t & nei0, + constant int64_t & nei1, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint 
sgitg[[simdgroup_index_in_threadgroup]]) { + + const int32_t i02 = tgpig.z; + tgpig.z = 0; + + device const uchar * src0 = src0s + i02*nb02; + + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + + // TODO: parallelize this loop + int64_t _ne1 = 0; + for (ushort ii1 = 0; ii1 < nei1; ii1++) { + for (ushort ii0 = 0; ii0 < nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + if (id == i02) { + //if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + //} + _ne1++; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + kernel_mul_mm_id_impl( + src0, + src1, + rowids, + dst, + ne00, + ne02, + nb01, + nb02, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + _ne1, + ne0*ne1, + shared_memory, + tgpig, + tiitg, + sgitg); +} + +#define QK_NL 16 + +// +// get rows +// + +typedef decltype(kernel_get_rows_f) get_rows_f_t; + +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +#if defined(__HAVE_BFLOAT__) +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +#endif + +typedef decltype(kernel_get_rows_q) get_rows_q_t; + +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; + +// +// matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm) mat_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +#if defined(__HAVE_BFLOAT__) +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +#endif +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template 
[[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; + +// +// indirect matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; + +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + +// +// matrix-vector multiplication +// + +typedef void (kernel_mul_mv_impl_t)( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, 
+ uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg); + +typedef void (kernel_mul_mv2_impl_t)( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiisg, + uint sgitg); + +template +void mmv_fn( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); +} + +template +void mmv_fn( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + int64_t ne13, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint64_t nb1, + uint r2, + uint r3, + threadgroup int8_t * shared_values, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); +} + +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + device const char * src0s, + device const char * src1, + device float * dst, + device const char * ids, + constant int64_t & nei0, + constant int64_t & nei1, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + threadgroup int8_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/nei0; + const int idx = tgpig.z%nei0; + + tgpig.z = 0; + + const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + + const int64_t i11 = idx % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*nb02; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + impl_fn( + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + /* ne00 */ ne00, + /* ne01 */ ne01, + /* ne02 */ 1,//ne02, + /* nb00 */ nb00, + /* nb01 */ nb01, + /* nb02 */ nb02, + /* ne10 */ ne10, + /* ne11 */ 1,//ne11, + /* ne12 */ 1,//ne12, + /* ne13 */ 1,//ne13, + /* nb10 */ nb10, + /* nb11 */ nb11, + /* nb12 */ nb12, + /* ne0 */ ne0, + /* ne1 */ 1,//ne1, + /* nb1 */ nb1, + /* r2 */ 1, + /* r3 
*/ 1, + shared_values, + tgpig, + tiitg, + tiisg, + sgitg); +} + +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + +kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + + float res = -INFINITY; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * IW + j]); + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} + +kernel void kernel_pool_2d_avg_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant 
int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (k0 * k1); + + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * IW + j]; + res += cur * scale; + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} diff --git a/mcandle/mcandle-metal-kernels/src/random.metal b/mcandle/mcandle-metal-kernels/src/random.metal new file mode 100644 index 000000000..c1a94199b --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/random.metal @@ -0,0 +1,214 @@ +#include +#include +#include + +using namespace metal; + +// Constants +// 2^32 and 1/2^32. Useful for converting between float and uint. +static constexpr constant ulong UNIF01_NORM32 = 4294967296; +static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10; +// 2 * pi +static constexpr constant float TWO_PI = 2.0 * M_PI_F; +static constexpr constant int3 S1 = {13, 19, 12}; +static constexpr constant int3 S2 = {2, 25, 4}; +static constexpr constant int3 S3 = {3, 11, 17}; + +// Used to prevent bad seeds. +static constexpr constant uint64_t PHI[16] = { + 0x9E3779B97F4A7C15, + 0xF39CC0605CEDC834, + 0x1082276BF3A27251, + 0xF86C6A11D0C18E95, + 0x2767F0B153D27B7F, + 0x0347045B5BF1827F, + 0x01886F0928403002, + 0xC1D64BA40F335E36, + 0xF06AD7AE9717877E, + 0x85839D6EFFBD7DC6, + 0x64D325D1C5371682, + 0xCADD0CCCFDFFBBE1, + 0x626E33B8D04B4331, + 0xBBF73C790D94F79D, + 0x471C4AB3ED3D82A5, + 0xFEC507705E4AE6E5, +}; + +// Combined Tausworthe and LCG Random Number Generator. +// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application +// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf +struct HybridTaus { + + float state; + + HybridTaus() thread = default; + HybridTaus() threadgroup = default; + HybridTaus() device = default; + HybridTaus() constant = default; + + // Generate seeds for each thread. + METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) { + return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL)); + } + + // Tausworthe generator. + METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) { + uint b = (((z << s.x) ^ z) >> s.y); + return (((z & M) << s.z) ^ b); + } + + // LCG generator. + METAL_FUNC static uint lcg(const uint z) { + return (1664525 * z + 1013904223UL); + } + + // Initialize the RNG state. 
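+    // The four seed lanes are first decorrelated by multiplying with the PHI
+    // constants; init() then runs repeated Tausworthe/LCG mixing rounds and folds
+    // the four component generators together with XOR, scaling the result into
+    // [0, 1) via UNIF01_INV32 to form the starting state.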
+ METAL_FUNC static HybridTaus init(const ulong4 seeds) { + uint4 seed = seed_per_thread(seeds); + + // Seed #1 + uint z1 = taus(seed.x, S1, 4294967294UL); + uint z2 = taus(seed.y, S2, 4294967288UL); + uint z3 = taus(seed.z, S3, 4294967280UL); + uint z4 = lcg(seed.x); + + // Seed #2 + uint r1 = (z1^z2^z3^z4^seed.y); + z1 = taus(r1, S1, 429496729UL); + z2 = taus(r1, S2, 4294967288UL); + z3 = taus(r1, S3, 429496280UL); + z4 = lcg(r1); + + // Seed #3 + r1 = (z1^z2^z3^z4^seed.z); + z1 = taus(r1, S1, 429496729UL); + z2 = taus(r1, S2, 4294967288UL); + z3 = taus(r1, S3, 429496280UL); + z4 = lcg(r1); + + // Seed #4 + r1 = (z1^z2^z3^z4^seed.w); + z1 = taus(r1, S1, 429496729UL); + z2 = taus(r1, S2, 4294967288UL); + z3 = taus(r1, S3, 429496280UL); + z4 = lcg(r1); + + HybridTaus rng; + rng.state = (z1^z2^z3^z4) * UNIF01_INV32; + return rng; + } + + METAL_FUNC float rand() { + uint seed = this->state * UNIF01_NORM32; + uint z1 = taus(seed, S1, 429496729UL); + uint z2 = taus(seed, S2, 4294967288UL); + uint z3 = taus(seed, S3, 429496280UL); + uint z4 = lcg(seed); + + thread float result = this->state; + this->state = (z1^z2^z3^z4) * UNIF01_INV32; + return result; + } +}; + +template METAL_FUNC void rand_uniform( + constant size_t &size, + constant float &min, + constant float &max, + device atomic_uint *seed, + device T *out, + uint tid [[thread_position_in_grid]] +) { + if (tid >= size) { + return; + } + + // Evenly sized vectors need an offset when writing the mirror element. + uint off = 1 - size % 2; + float diff = abs(min - max); + uint s = atomic_load_explicit(seed, memory_order_relaxed); + HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1}); + out[tid] = static_cast(rng.rand() * diff + min); + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + // Return early if tid == 0 && off == 0, otherwise we will write to out[size]. + if (off == 0) + return; + } + // Use symmetry to fill the other half of the array. + out[size - off - tid] = static_cast(rng.rand() * diff + min); +} + +// Create Gaussian normal distribution using Box-Muller transform: +// https://en.wikipedia.org/wiki/Box–Muller_transform +template METAL_FUNC void normal( + constant size_t &size, + constant float &mean, + constant float &stddev, + device atomic_uint *seed, + device T *out, + uint tid [[thread_position_in_grid]] +) { + if (tid >= size) { + return; + } + // Evenly sized vectors need an offset when writing the mirror element. + uint off = 1 - size % 2; + uint s = atomic_load_explicit(seed, memory_order_relaxed); + HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1}); + float u1 = rng.rand(); + float u2 = rng.rand(); + + float cosval; + float sinval = sincos(TWO_PI * u2, cosval); + float mag = stddev * sqrt(-2.0 * log(u1)); + float z0 = mag * cosval + mean; + float z1 = mag * sinval + mean; + + out[tid] = static_cast(z0); + + if (tid == 0) { + atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed); + // Return early if tid == 0 && off == 0, otherwise we will write to out[size]. + if (off == 0) + return; + } + // Use symmetry to fill the other half of the array. 
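+    // Box-Muller yields two independent normals per (u1, u2) pair: z0 was stored at
+    // out[tid] above, and z1 lands in the mirrored slot size - off - tid, so only
+    // half as many RNG draws are needed to fill the buffer.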
+ out[size - off - tid] = static_cast(z1); +} + +#define UNIFORM_OP(NAME, T) \ +kernel void rand_uniform_##NAME( \ + constant size_t &size, \ + constant float &min, \ + constant float &max, \ + device atomic_uint *seed, \ + device T *out, \ + uint tid [[thread_position_in_grid]] \ +) { \ + rand_uniform(size, min, max, seed, out, tid); \ +} \ + +#define NORMAL_OP(NAME, T) \ +kernel void rand_normal_##NAME( \ + constant size_t &size, \ + constant float &mean, \ + constant float &stddev, \ + device atomic_uint *seed, \ + device T *out, \ + uint tid [[thread_position_in_grid]] \ +) { \ + normal(size, mean, stddev, seed, out, tid); \ +} \ + + +#define RANDOM_OPS(NAME, T) \ +UNIFORM_OP(NAME, T) \ +NORMAL_OP(NAME, T) \ + +RANDOM_OPS(f32, float) +RANDOM_OPS(f16, half) + +#if __METAL_VERSION__ >= 310 +RANDOM_OPS(bf16, bfloat) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/reduce.metal b/mcandle/mcandle-metal-kernels/src/reduce.metal new file mode 100644 index 000000000..f82eea076 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/reduce.metal @@ -0,0 +1,827 @@ +#include +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +constant int THREADGROUP_SIZE = 2048; + +template +METAL_FUNC void argmin( + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t &el_to_sum_per_block, + device const T *src, + device uint *dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup T *shared_memory, + threadgroup uint *shared_indices +) { + bool notset = true; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (notset || src[strided_i] < shared_memory[tid]) { + shared_memory[tid] = src[strided_i]; + /* Assume that the reduction takes place over the last dimension which is contiguous. 
*/ + shared_indices[tid] = idx % dims[num_dims - 1]; + notset = false; + } + idx += block_dim; + } + + threadgroup_barrier(mem_flags::mem_none); + // reduction in shared memory + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { + shared_indices[tid] = shared_indices[tid + s]; + shared_memory[tid] = shared_memory[tid + s]; + } \ + threadgroup_barrier(mem_flags::mem_none); + } + if (tid == 0) { + dst[dst_id] = shared_indices[0]; + } +} + +#define ARGMIN(NAME, T, MAXVALUE) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + shared_memory[tid] = MAXVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + argmin(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \ +} \ + + +template +METAL_FUNC void argmax( + constant size_t & num_dims, + constant size_t * dims, + constant size_t * strides, + constant size_t & el_to_sum_per_block, + device const T * src, + device uint * dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup T * shared_memory, + threadgroup uint * shared_indices + ) { + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + bool notset = true; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. 
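+        // Walk this block's slice with a stride of block_dim, remembering both the
+        // best value seen so far and its position inside the innermost (reduced)
+        // dimension; `notset` lets the first visited element win unconditionally.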
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + if (notset || shared_memory[tid] < src[strided_i]) { + shared_memory[tid] = src[strided_i]; + shared_indices[tid] = idx % dims[num_dims - 1]; + notset = false; + } + idx += block_dim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { + shared_indices[tid] = shared_indices[tid + s]; + shared_memory[tid] = shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + // Thread 0 writes the result of the reduction + if (tid == 0) { + dst[dst_id] = shared_indices[0]; + } + } + +#define ARGMAX(NAME, T, MINVALUE) \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device uint *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + threadgroup uint shared_indices[THREADGROUP_SIZE]; \ + shared_memory[tid] = MINVALUE; \ + shared_indices[tid] = 0xFFFFFFFF; \ + argmax(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \ +} \ + +template +METAL_FUNC void reduce( + constant size_t & num_dims, + constant size_t * dims, + constant size_t * strides, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup T * shared_memory, + T (*fn)(T, T) +) { + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = start_idx + el_to_sum_per_block; + size_t idx = start_idx + tid; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. 
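+        // Fold each strided element into this thread's partial accumulator with the
+        // supplied binary op; partials are then combined across threads by the
+        // shared-memory tree reduction below.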
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + T x = shared_memory[tid]; + T y = src[strided_i]; + shared_memory[tid] = fn(x, y); + idx += block_dim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + T x = shared_memory[tid]; + T y = shared_memory[tid + s]; + shared_memory[tid] = fn(x, y); + } + threadgroup_barrier(mem_flags::mem_none); + } + + if (tid == 0) { + dst[dst_id] = shared_memory[0]; + } +} + +#define REDUCE(FN, NAME, T, START) \ +METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \ +kernel void NAME( \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup T shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = START; \ + reduce(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \ +} \ + +template +METAL_FUNC void softmax( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + float tmp = -INFINITY; + while (idx < stop_idx) { + tmp = MAX(tmp, float(src[idx])); + idx += block_dim; + } + shared_memory[tid] = tmp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);\ + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float _max = shared_memory[0]; + + /* prevent tid=0 from overwriting _max before other threads have written it */ + threadgroup_barrier(mem_flags::mem_threadgroup); + shared_memory[tid] = 0; + + idx = start_idx + tid; + while (idx < stop_idx) { + const float val = exp(float(src[idx]) - _max); + dst[idx] = T(val); + shared_memory[tid] += val; + idx += block_dim; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const T inv_acc = T(1.0 / shared_memory[0]); + idx = start_idx + tid; + while (idx < stop_idx) { + dst[idx] *= inv_acc; + idx += block_dim; + } +} + +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + softmax(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \ +} \ + +template +METAL_FUNC void rmsnorm( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device 
const T * src, + device T * dst, + device const T * alpha, + constant float & eps, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + float tmp = 0; + while (idx < stop_idx) { + tmp = tmp + float(src[idx]) * float(src[idx]); + idx += block_dim; + } + shared_memory[tid] = tmp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); + float inv_norm = 1.0f / norm; + idx = start_idx + tid; + while (idx < stop_idx) { + float val = float(src[idx]) * inv_norm; + if (alpha != nullptr) { + val *= float(alpha[idx - start_idx]); + } + dst[idx] = T(val); + idx += block_dim; + } +} + +template +METAL_FUNC void layernorm( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + device const T * alpha, + device const T * beta, + constant float & eps, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + float tmp1 = 0; + float tmp2 = 0; + while (idx < stop_idx) { + tmp1 += float(src[idx]); + tmp2 += float(src[idx]) * float(src[idx]); + idx += block_dim; + } + shared_memory[tid] = tmp1; + shared_memory[tid + block_dim] = tmp2; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; + shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = shared_memory[0] / float(el_to_sum_per_block); + float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean; + float inv_norm = 1.0f / sqrt(var + eps); + idx = start_idx + tid; + while (idx < stop_idx) { + float val = (float(src[idx]) - mean) * inv_norm; + if (alpha != nullptr) { + val *= float(alpha[idx - start_idx]); + } + if (beta != nullptr) { + val += float(beta[idx - start_idx]); + } + dst[idx] = T(val); + idx += block_dim; + } +} + +#define RMSNORM(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + device const T *alpha, \ + constant float &eps, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = 0; \ + rmsnorm(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \ +} \ + +#define LAYERNORM(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, 
\ + device T *dst, \ + device const T *alpha, \ + device const T *beta, \ + constant float &eps, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = 0; \ + layernorm(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \ +} \ + +template +METAL_FUNC void ropei( + constant size_t &bh, + constant size_t &td, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint tid +) { + if (2 * tid >= bh * td) { + return; + } + size_t rope_idx = tid % (td / 2); + T c = cos[rope_idx]; + T s = sin[rope_idx]; + dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; + dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; +} + +template +METAL_FUNC void rope( + constant size_t &bh, + constant size_t &td, + constant size_t &d, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint idx +) { + if (2 * idx >= bh * td) { + return; + } + size_t i_bh = idx / (td / 2); + size_t i_td = idx - (td / 2) * i_bh; + size_t i_t = i_td / (d / 2); + size_t i_d = i_td - (d / 2) * i_t; + size_t i1 = i_bh * td + i_t * d + i_d; + size_t i2 = i1 + d / 2; + size_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +template +METAL_FUNC void rope_thd( + constant size_t &b, + constant size_t &t, + constant size_t &h, + constant size_t &d, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint idx +) { + if (2 * idx >= b * t * h * d) { + return; + } + const size_t i_bth = idx / (d / 2); + const size_t i_d = idx - (d / 2) * i_bth; + const size_t i_t = (i_bth / h) % t; + const size_t i1 = i_bth * d + i_d; + const size_t i2 = i1 + d / 2; + const size_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \ +kernel void FN_NAME_I( \ + constant size_t &bh, \ + constant size_t &td, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + ropei(bh, td, src, cos, sin, dst, tid); \ +}\ +kernel void FN_NAME( \ + constant size_t &bh, \ + constant size_t &td, \ + constant size_t &d, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint idx [[ thread_position_in_grid ]] \ +) { \ + rope(bh, td, d, src, cos, sin, dst, idx); \ +}\ +kernel void FN_NAME_THD( \ + constant size_t &b, \ + constant size_t &t, \ + constant size_t &h, \ + constant size_t &d, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint idx [[ thread_position_in_grid ]] \ +) { \ + rope_thd(b, t, h, d, src, cos, sin, dst, idx); \ +}\ + + +template +kernel void attn_soft_max( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint 
tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const T * psrc0 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device T * pdst = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + lmax = MAX(lmax, ((float)psrc0[i00])*scale + (pmask ? slope*((float)pmask[i00]) : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp((((float)psrc0[i00])*scale + (pmask ? slope*((float)pmask[i00]) : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = static_cast(exp_psrc0); + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= static_cast(inv_sum); + } +} + +template +kernel void attn_soft_max_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const T * psrc4 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device T * pdst4 = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, ((float4)psrc4[i00])*scale + (float4)((pmask ? 
slope*((float4)pmask[i00]) : 0.0f))); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((((float4)psrc4[i00])*scale + (float4)((pmask ? slope*((float4)pmask[i00]) : 0.0f))) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = static_cast(exp_psrc4); + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] = pdst4[i00] * static_cast((S)inv_sum); + } +} + + +REDUCE(x + y, fast_sum_f32_strided, float, 0) +REDUCE(x + y, fast_sum_u32_strided, uint, 0) +REDUCE(x + y, fast_sum_f16_strided, half, 0) +REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0) +REDUCE(x * y, fast_mul_f32_strided, float, 1) +REDUCE(x * y, fast_mul_u32_strided, uint, 1) +REDUCE(x * y, fast_mul_f16_strided, half, 1) +REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF) +REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0) +REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH) +REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0) +REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF) +REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF) +REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH) +REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF) +ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF) +ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH) +ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF) +ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF) +ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF) +ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH) +ARGMAX(fast_argmax_u32_strided, uint, 0) +ARGMAX(fast_argmax_u8_strided, uint8_t, 0) + +SOFTMAX(softmax_f32, float) +SOFTMAX(softmax_f16, half) +// Softmax for attention +typedef decltype(attn_soft_max) attn_soft_max_t; +typedef decltype(attn_soft_max_4) attn_soft_max_4_t; +template [[host_name("attn_soft_max_f16")]] kernel attn_soft_max_t attn_soft_max; +template [[host_name("attn_soft_max_f32")]] kernel attn_soft_max_t attn_soft_max; +template [[host_name("attn_soft_max_f16_4")]] kernel attn_soft_max_4_t attn_soft_max_4; +template [[host_name("attn_soft_max_f32_4")]] kernel attn_soft_max_4_t attn_soft_max_4; +RMSNORM(rmsnorm_f32, float) +RMSNORM(rmsnorm_f16, half) +LAYERNORM(layernorm_f32, float) +LAYERNORM(layernorm_f16, half) +ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) +ROPE(rope_f16, rope_i_f16, rope_thd_f16, half) + +#if __METAL_VERSION__ >= 220 +REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) +REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN) 
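+// INT_MAX / INT_MIN serve as the identity sentinels for the int64 argmin/argmax below.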
+ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) +ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) +#endif + +REDUCE(x + y, fast_sum_i32_strided, int32_t, 0) +REDUCE(MIN(x, y), fast_min_i32_strided, int32_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN) +ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX) +ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN) + +REDUCE(x + y, fast_sum_i16_strided, int16_t, 0) +REDUCE(MIN(x, y), fast_min_i16_strided, int16_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i16_strided, int16_t, INT_MIN) +ARGMIN(fast_argmin_i16_strided, int16_t, INT_MAX) +ARGMAX(fast_argmax_i16_strided, int16_t, INT_MIN) + +#if defined(__HAVE_BFLOAT__) +REDUCE(x + y, fast_sum_bf16, bfloat, 0) +REDUCE(x + y, fast_sum_bf16_strided, half, 0) +REDUCE(x * y, fast_mul_bf16, bfloat, 1) +REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1) +REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) +REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF) +REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF) +REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF) +ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) +ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) +SOFTMAX(softmax_bf16, bfloat) +// // Softmax for attention +template [[host_name("attn_soft_max_bf16")]] kernel attn_soft_max_t attn_soft_max; +template [[host_name("attn_soft_max_bf16_4")]] kernel attn_soft_max_4_t attn_soft_max_4; +RMSNORM(rmsnorm_bf16, bfloat) +LAYERNORM(layernorm_bf16, bfloat) +ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/scaled_dot_product_attention.metal b/mcandle/mcandle-metal-kernels/src/scaled_dot_product_attention.metal new file mode 100644 index 000000000..1abb9f080 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/scaled_dot_product_attention.metal @@ -0,0 +1,1257 @@ +// Updated from MLX commit has f70764a + +#include +#include + +using namespace metal; + +// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" + +struct MLXFastAttentionParams { + const int M; + const int N; + const int K; + + const int ldq; // ldq == ldo + const int ldk; + const int ldv; + const int lds; + const int ldo; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_q; + const int batch_stride_k; + const int batch_stride_v; + const int batch_stride_o; + + const int swizzle_log; + const int gemm_n_iterations_aligned; + const int gemm_k_iterations_aligned; + const int gemm_sv_m_block_iterations; + + const int batch_ndim; + const float alpha; + const float softcapping; +}; + +struct MLXScaledDotProductAttentionParams { + // Associated dimensions & transposition information + const uint QUERY_SEQUENCE_LENGTH = 1; + const uint N_Q_HEADS = 32; + const uint N_KV_HEADS = 32; + const uint KV_TILES = 1; + const float INV_ALPHA = 0.08838834764831843f; +}; + +// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector" + +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& v_stride, + const constant float& scale, + const constant float& softcapping, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; 
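+  // One threadgroup handles one query head: BN simdgroups stride through the keys
+  // (each taking every BN-th key row), while the BD lanes of a simdgroup split the
+  // head dimension D into D/BD elements per thread.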
+ constexpr int elem_per_thread = D / BD; + + const int stride = BN * D; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + simd_lid * elem_per_thread; + keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; + out += head_idx * D + simd_gid * elem_per_thread; + + // Read the query and 0 the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (int i = simd_gid; i < N; i += BN) { + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + if (softcapping != 1.) { + score = precise::tanh(score); + score = score * softcapping; + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } + + // Move the pointers to the next kv + keys += stride; + values += stride; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each thread has a partial part of the output so we need to combine them. 
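+  // This is the usual streaming-softmax merge: every simdgroup's partial output is
+  // rescaled by exp(local_max - global_max) before the rescaled exponential sums
+  // and outputs are reduced across simdgroups and normalized.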
+ + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +// ============ "mlx/backend/metal/kernels/steel/defines.h" + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// ============ "mlx/backend/metal/kernels/utils.h" + +#if defined(__HAVE_BFLOAT__) +typedef bfloat bfloat16_t; +#endif +typedef half float16_t; + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} + +// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.metal" + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short 
alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderFA { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoaderFA( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } + METAL_FUNC void next(short n) { + src += n * tile_stride; + } +}; + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMAFA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? 
TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + ushort sid; + ushort slid; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMAFA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + slid = simd_lane_id; + sid = simd_group_id; + + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + METAL_FUNC void rescale_output(const threadgroup float* Corrections) { + // Loop over all simdgroup tiles + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + short row = sm + tm + i * TM_stride; + float scale_value = Corrections[row]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + // int offset = (i * TM_stride) * ldc + (j * TN_stride); + accum[0] *= scale_value; + accum[1] *= scale_value; + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* C, const int ldc) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out C + C[offset] = outs[0]; + C[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_to_tgp_memory( + threadgroup U* C, + const int ldc, + short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + METAL_FUNC void + store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 
TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } + + METAL_FUNC void clear_results() { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + results[i * TN + j] = simdgroup_matrix(0); + } + } + } +}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct FastAttentionKernel { + STEEL_CONST short tgp_padding = 16 / sizeof(T); + STEEL_CONST short float_padding = 16 / sizeof(float); + STEEL_CONST short tgp_mem_size_q = + transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_k = + transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_v = + transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); + + // maxes, rowsums, rescale + STEEL_CONST short tgp_mem_size_corrections = + 4 * (BM * sizeof(float) + float_padding); + + STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; + + STEEL_CONST short tgp_mem_size = share_kv_smem + ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + tgp_mem_size_v; + + STEEL_CONST short tgp_size = WM * WN * 32; + + static_assert(transpose_q == false, "Expected Q not transposed."); + static_assert(transpose_k == true, "Expected K transposed."); + static_assert(transpose_v == false, "Expected V not transposed."); + static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); + + using loader_q_t = BlockLoaderFA< + T, + transpose_q ? BK : BM, + transpose_q ? BM : BK, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + !transpose_q, + tgp_size>; + + using loader_k_t = BlockLoaderFA< + T, + transpose_k ? 
BN : BK, + transpose_k ? BK : BN, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + transpose_k, + tgp_size>; + + using loader_v_t = BlockLoaderFA< + T, + transpose_v ? BK : BN, + transpose_v ? BN : BK, + transpose_v ? BN + tgp_padding : BK + tgp_padding, + transpose_v, + tgp_size>; + + using mma_qk_t = BlockMMAFA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + AccumType, + Epilogue>; + + using mma_sv_t = BlockMMAFA< + T, + U, + BM, + BK, + BN, + WM, + WN, + false, + transpose_v, + BN + tgp_padding, + BK + tgp_padding, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_k_t& loader_b, + thread mma_qk_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + (void)tgp_bm; + + short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + // not valid for gemm_k_iterations > 1 (so, BK == d_k) + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + } + + static METAL_FUNC void initialize_corrections( + threadgroup float* C, + uint simd_lane_id, + uint simd_group_id) { + if (simd_group_id == 0) { + threadgroup float* maxes = C; + threadgroup float* sums = C + (BM + float_padding); + threadgroup float* o_rescale = sums + (BM + float_padding); + threadgroup float* output_rescale = o_rescale + (BM + float_padding); + + if (simd_lane_id < BM) { + maxes[simd_lane_id] = -INFINITY; // m_i + sums[simd_lane_id] = 0.f; // l_i + o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) + output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + } + } + } + + static METAL_FUNC void rescale_ss( + threadgroup T* Ss, + threadgroup float* Corrections, + uint simd_group_id, + uint simd_lane_id, + short2 local_blocks, + float alpha, + float softcapping) { + if (simd_group_id == 0) { + short row_offset = BM + float_padding; + threadgroup float* maxes = Corrections; + threadgroup float* sums = Corrections + row_offset; + threadgroup float* o_rescale = sums + row_offset; + threadgroup float* output_scales = o_rescale + row_offset; + + if (simd_lane_id < uint(local_blocks.y)) { + float m_i_old = maxes[simd_lane_id]; + float l_i_old = sums[simd_lane_id]; + + float m_i_new = m_i_old; + float l_i_new = l_i_old; + + short offset = simd_lane_id * (BN + tgp_padding); + + float m_ij = -INFINITY; + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + if (softcapping != 1.) { + val = precise::tanh(val); + val = val * softcapping; + } + m_ij = max(m_ij, val); + } + + m_i_new = max(m_ij, m_i_new); + + float rowsum = 0.f; // lij + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + if (softcapping != 1.) 
{ + val = precise::tanh(val); + val = val * softcapping; + } + float P_i_j = exp(val - m_ij); + rowsum += P_i_j; + P_i_j = P_i_j * exp(m_ij - m_i_new); + Ss[offset + j] = T(P_i_j); + } + + l_i_new = + exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; + maxes[simd_lane_id] = m_i_new; + sums[simd_lane_id] = l_i_new; + float rescale = l_i_old * exp(m_i_old - m_i_new); + o_rescale[simd_lane_id] = rescale; + output_scales[simd_lane_id] = 1.0 / l_i_new; + } + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device U* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + threadgroup T* Qs [[threadgroup(0)]], + threadgroup T* Ks [[threadgroup(1)]], + threadgroup T* Ss [[threadgroup(2)]], + threadgroup T* Vs [[threadgroup(3)]], + threadgroup float* Corrections [[threadgroup(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in Q, O; and head in K, V. + const int c_row = tid_y * BM; + + Q += transpose_q ? c_row : c_row * params->ldq; + thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); + + short tgp_bm = min(BM, params->M - c_row); + short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + loader_q.load_safe(tile_dims_Q); + + initialize_corrections(Corrections, simd_lane_id, simd_group_id); + + O += c_row * params->ldo; + + // Prepare threadgroup mma operation + thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); + thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); + thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); + thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); + + for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; + n_block++) { + short c_col = BN; + + // Prepare threadgroup loading operations + short gemm_k_iterations = params->gemm_k_iterations_aligned; + short tgp_bn_qk = min(BN, params->N - c_col * n_block); + threadgroup_barrier(mem_flags::mem_none); + + /////////////////////////////////////////////////////////////////////////////// + { // Loop over K - unaligned case + + if (tgp_bm == BM && tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } else if (tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else if (tgp_bm == BM) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } + } + + mma_qk_op.store_result_to_tgp_memory( + Ss, BN + tgp_padding, short2(BN, BM)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + rescale_ss( + Ss, + Corrections, + simd_group_id, + simd_lane_id, + short2(tgp_bn_qk, tgp_bm), + params->alpha, + params->softcapping); + + loader_v.load_safe(short2(BK, tgp_bn_qk)); + + 
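// Online-softmax update for this KV block: scale the running output tile by l_old * exp(m_old - m_new) (the o_rescale corrections), accumulate softmax(S)·V, then renormalize by 1/l_new (the output_rescale corrections) below. +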
threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); + mma_softmax_sv_op.rescale_output(o_scales); + + mma_softmax_sv_op.mma(Ss, Vs); + + threadgroup float* final_output_scales = + Corrections + 3 * (BM + float_padding); + + mma_softmax_sv_op.rescale_output(final_output_scales); + + loader_v.next(); + loader_k.next(BN); + + mma_qk_op.clear_results(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + } +}; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using attention_kernel = FastAttentionKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_v, + MN_aligned, + K_aligned>; + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* Q_bstrides = batch_strides; + const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); + + Q += batch_offsets.x; + K += batch_offsets.y; + V += batch_offsets.y; + + } else { + Q += params->batch_stride_q * tid.z; + K += params->batch_stride_k * tid.z; + V += params->batch_stride_v * tid.z; + } + + // same shape as input + O += params->batch_stride_o * tid.z; + threadgroup T Qs[attention_kernel::tgp_mem_size_q]; + threadgroup T Ss[attention_kernel::tgp_mem_size_s]; + threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; + + if (attention_kernel::share_kv_smem) { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } else { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T Vs[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } +} + +// clang-format off + +// SDPA full instantiations +#define instantiate_fast_inference_self_attention_kernel( \ + itype, otype, bm, bn, bk, wm, wn) \ + template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ + "_itype_" #itype)]] [[kernel]] void \ + attention( \ + const device itype* Q [[buffer(0)]], \ + const device itype* K [[buffer(1)]], \ + const device itype* V [[buffer(2)]], \ + device otype* O [[buffer(3)]], \ + const constant MLXFastAttentionParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(5)]], \ + const constant size_t* batch_strides [[buffer(6)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid 
[[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 32, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 64, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 96, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 128, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 256, + 2, + 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 32, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 96, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); + +// SDPA vector instantiations +#define instantiate_sdpa_vector(type, head_dim) \ + template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \ + [[kernel]] void sdpa_vector( \ + const device type* queries [[buffer(0)]], \ + const device type* keys [[buffer(1)]], \ + const device type* values [[buffer(2)]], \ + device type* out [[buffer(3)]], \ + const constant int& gqa_factor, \ + const constant int& N, \ + const constant size_t& k_stride, \ + const constant size_t& v_stride, \ + const constant float& scale, \ + const constant float& softcapping, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_sdpa_vector_heads(type) \ + instantiate_sdpa_vector(type, 32) \ + instantiate_sdpa_vector(type, 64) \ + instantiate_sdpa_vector(type, 96) \ + instantiate_sdpa_vector(type, 128) \ + instantiate_sdpa_vector(type, 256) + +instantiate_sdpa_vector_heads(float) +#if defined(__HAVE_BFLOAT__) +instantiate_sdpa_vector_heads(bfloat16_t) +#endif +instantiate_sdpa_vector_heads(float16_t) + // clang-format on diff --git a/mcandle/mcandle-metal-kernels/src/sort.metal b/mcandle/mcandle-metal-kernels/src/sort.metal new file mode 100644 index 000000000..9f001d8fb --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/sort.metal @@ -0,0 +1,99 @@ +// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal +#include +using namespace metal; + +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define SORT_ASC 1 +#define SORT_DESC 0 + +template +METAL_FUNC void argsort( + device const T * x, + device uint32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup uint32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols_pad) return; + + device const T * x_row = x + row * ncols; + threadgroup uint32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == SORT_ASC ? 
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == SORT_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +#define ARGSORT(T, RUST_T) \ +kernel void asort_asc_##RUST_T( \ + device const T * x, \ + device uint32_t * dst, \ + constant int64_t & ncols, \ + constant int64_t & ncols_pad, \ + threadgroup uint32_t * shared_values [[threadgroup(0)]], \ + uint3 tgpig[[threadgroup_position_in_grid]], \ + uint3 tpitg[[thread_position_in_threadgroup]] \ +) { \ + argsort(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \ +} \ +kernel void asort_desc_##RUST_T( \ + device const T * x, \ + device uint32_t * dst, \ + constant int64_t & ncols, \ + constant int64_t & ncols_pad, \ + threadgroup uint32_t * shared_values [[threadgroup(0)]], \ + uint3 tgpig[[threadgroup_position_in_grid]], \ + uint3 tpitg[[thread_position_in_threadgroup]] \ +) { \ + argsort(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \ +} \ + +ARGSORT(float, f32) +ARGSORT(half, f16) +ARGSORT(uint8_t, u8) +ARGSORT(uint32_t, u32) +ARGSORT(int32_t, i32) +ARGSORT(int16_t, i16) + +#if __METAL_VERSION__ >= 220 +ARGSORT(int64_t, i64) +#endif +#if defined(__HAVE_BFLOAT__) +ARGSORT(bfloat, bf16) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/ternary.metal b/mcandle/mcandle-metal-kernels/src/ternary.metal new file mode 100644 index 000000000..98aacd003 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/ternary.metal @@ -0,0 +1,115 @@ +#include +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +template +METAL_FUNC void where_cond( + constant size_t &numel, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t *strides_t, + constant size_t *strides_f, + device const ID *ids, + device const T *t, + device const T *f, + device T *out, + uint i [[ thread_position_in_grid ]] +) { + if (i >= numel){ + return; + } + uint strided_i = get_strided_index(i, num_dims, dims, strides); + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); + out[i] = ids[strided_i] ? 
t[strided_i_t] : f[strided_i_f]; +} + +#define WHERE_OP(T, ID, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID *ids, \ + device const T *t, \ + device const T *f, \ + device T *out, \ + uint i [[ thread_position_in_grid ]] \ +) { \ + where_cond(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \ +} \ + +WHERE_OP(half, uint32_t, where_u32_f16) +WHERE_OP(float, uint32_t, where_u32_f32) +WHERE_OP(uint8_t, uint32_t, where_u32_u8) +WHERE_OP(uint32_t, uint32_t, where_u32_u32) + +WHERE_OP(half, uint8_t, where_u8_f16) +WHERE_OP(float, uint8_t, where_u8_f32) +WHERE_OP(uint8_t, uint8_t, where_u8_u8) +WHERE_OP(uint32_t, uint8_t, where_u8_u32) + +#if __METAL_VERSION__ >= 220 +WHERE_OP(int64_t, uint8_t, where_u8_i64) +WHERE_OP(int64_t, uint32_t, where_u32_i64) + +WHERE_OP(half, int64_t, where_i64_f16) +WHERE_OP(float, int64_t, where_i64_f32) +WHERE_OP(uint8_t, int64_t, where_i64_u8) +WHERE_OP(uint32_t, int64_t, where_i64_u32) +WHERE_OP(int64_t, int64_t, where_i64_i64) +WHERE_OP(int64_t, int32_t, where_i64_i32) +WHERE_OP(int64_t, int16_t, where_i64_i16) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int64_t, where_i64_bf16) +#endif +#endif + +WHERE_OP(int64_t, uint8_t, where_u8_i32) +WHERE_OP(int64_t, uint32_t, where_u32_i32) + +WHERE_OP(half, int32_t, where_i32_f16) +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int64_t, int32_t, where_i32_i64) +WHERE_OP(int32_t, int32_t, where_i32_i32) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int32_t, where_i32_bf16) +#endif + +WHERE_OP(int64_t, uint8_t, where_u8_i16) +WHERE_OP(int64_t, uint32_t, where_u32_i16) + +WHERE_OP(half, int16_t, where_i16_f16) +WHERE_OP(float, int16_t, where_i16_f32) +WHERE_OP(uint8_t, int16_t, where_i16_u8) +WHERE_OP(uint32_t, int16_t, where_i16_u32) +WHERE_OP(int64_t, int16_t, where_i16_i64) +WHERE_OP(int32_t, int16_t, where_i16_i32) +WHERE_OP(int16_t, int16_t, where_i16_i16) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int16_t, where_i16_bf16) +#endif + +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, uint8_t, where_u8_bf16) +WHERE_OP(bfloat, uint32_t, where_u32_bf16) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/tests.rs b/mcandle/mcandle-metal-kernels/src/tests.rs new file mode 100644 index 000000000..028694d2a --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/tests.rs @@ -0,0 +1,2342 @@ +use super::*; +use half::{bf16, f16}; +use metal::MTLResourceOptions; +use rand::Rng; + +fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} + +fn new_buffer(device: &Device, data: &[T]) -> Buffer { + let options = MTLResourceOptions::StorageModeManaged; + let ptr = data.as_ptr() as *const c_void; + let size = std::mem::size_of_val(data) as u64; + device.new_buffer_with_data(ptr, size, options) +} + +fn device() -> Device { + Device::system_default().unwrap() +} + +fn approx(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t * b) / b).collect() +} + +fn approx_f16(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + +fn approx_bf16(v: Vec, digits: i32) -> Vec { + let b = 
10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + +fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: 0, + }; + let output = new_buffer(&device, v); + call_unary_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + input, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, v.len()) +} + +fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let left = new_buffer(&device, x); + let right = new_buffer(&device, y); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + call_binary_contiguous( + &device, + command_buffer, + &kernels, + name, + x.len(), + BufferOffset::zero_offset(&left), + BufferOffset::zero_offset(&right), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, x.len()) +} + +fn run_strided( + v: &[T], + kernel: unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + offset: usize, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: offset, + }; + let output_b = new_buffer(&device, v); + let output = BufferOffset { + buffer: &output_b, + offset_in_bytes: 0, + }; + let kernels = Kernels::new(); + call_unary_strided( + &device, + command_buffer, + &kernels, + kernel, + shape, + input, + strides, + output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output_b, v.len()) +} + +#[test] +fn cos_f32() { + let v = vec![1.0f32, 2.0, 3.0]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); + assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); 
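+ // With strides [2, 1] the [3, 2] view is contiguous, so the strided kernel should reproduce the flat element-wise result.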
+ assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Transposed + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![1, 3]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Very large + let v = vec![1.0f32; 10_000]; + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_strided_random() { + let v: Vec<_> = (0..10_000).map(|_| rand::random::()).collect(); + let shape = vec![5_000, 2]; + let strides = vec![1, 5_000]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4)); + assert_eq!( + approx(vec![results[1]], 4), + approx(vec![expected[5_000]], 4) + ); + assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4)); + assert_eq!( + approx(vec![results[3]], 4), + approx(vec![expected[5_001]], 4) + ); + assert_eq!( + approx(vec![results[5_000]], 4), + approx(vec![expected[2_500]], 4) + ); +} + +#[test] +fn gelu_f16() { + let v: Vec = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + +#[test] +fn silu_f16() { + let v: Vec = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0]; + let results = run(&v, unary::contiguous::silu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn silu_f32() { + let v: Vec = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0]; + let results = run(&v, unary::contiguous::silu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + +#[test] +fn binary_add_f32() { + let left = vec![1.0f32, 2.0, 3.0]; + let right = vec![2.0f32, 3.1, 4.2]; + let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); + let expected: Vec<_> = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| x + y) + .collect(); + assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); + assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); +} + +#[test] +fn binary_ops_bf16() { + let lhs: Vec = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect(); + let rhs: Vec = [4.2f32, 
5.5f32, 6.91f32] + .into_iter() + .map(bf16::from_f32) + .collect(); + + macro_rules! binary_op { + ($opname:ident, $opexpr:expr) => {{ + let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT); + let expected: Vec = lhs + .iter() + .zip(rhs.iter()) + .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y)) + .collect(); + assert_eq!(results, expected); + }}; + } + + binary_op!(add, |x, y| x + y); + binary_op!(sub, |x, y| x - y); + binary_op!(mul, |x, y| x * y); + binary_op!(div, |x, y| x / y); + binary_op!(min, |x: bf16, y| x.min(y)); + binary_op!(max, |x: bf16, y| x.max(y)); +} + +fn run_cast(v: &[T], name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; + let size = (v.len() * std::mem::size_of::()) as u64; + let output = device.new_buffer(size, options); + + call_cast_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + BufferOffset::zero_offset(&input), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, v.len()) +} + +#[test] +fn cast_f32() { + let v_f64 = [1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); + + // f32 -> f16 + let results: Vec = run_cast(&v_f32, "cast_f32_f16"); + assert_eq!(results, v_f16); + + // f32 -> bf16 + let results: Vec = run_cast(&v_f32, "cast_f32_bf16"); + assert_eq!(results, v_bf16); + + // f32 -> u32 + let results: Vec = run_cast(&v_f32, "cast_f32_u32"); + assert_eq!(results, v_u32); + + // f32 -> u8 + let results: Vec = run_cast(&v_f32, "cast_f32_u8"); + assert_eq!(results, v_u8); + + // f32 -> i64 + let results: Vec = run_cast(&v_f32, "cast_f32_i64"); + assert_eq!(results, v_i64); +} + +#[test] +fn cast_f16() { + let v_f64 = [1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); + + // f16 -> f32 + let results: Vec = run_cast(&v_f16, "cast_f16_f32"); + assert_eq!(results, v_f32); + + // f16 -> bf16 + let results: Vec = run_cast(&v_f16, "cast_f16_bf16"); + assert_eq!(results, v_bf16); + + // f16 -> u32 + let results: Vec = run_cast(&v_f16, "cast_f16_u32"); + assert_eq!(results, v_u32); + + // f16 -> u8 + let results: Vec = run_cast(&v_f16, "cast_f16_u8"); + assert_eq!(results, v_u8); + + // f16 -> i64 + let results: Vec = run_cast(&v_f16, "cast_f16_i64"); + assert_eq!(results, v_i64); +} + +#[test] +fn cast_bf16() { + let v_f64 = [1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v 
as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); + + // bf16 -> f32 + let results: Vec = run_cast(&v_bf16, "cast_bf16_f32"); + assert_eq!(results, v_f32); + + // bf16 -> f16 + let results: Vec = run_cast(&v_bf16, "cast_bf16_f16"); + assert_eq!(results, v_f16); + + // bf16 -> u32 + let results: Vec = run_cast(&v_bf16, "cast_bf16_u32"); + assert_eq!(results, v_u32); + + // bf16 -> u8 + let results: Vec = run_cast(&v_bf16, "cast_bf16_u8"); + assert_eq!(results, v_u8); + + // bf16 -> i64 + let results: Vec = run_cast(&v_bf16, "cast_bf16_i64"); + assert_eq!(results, v_i64); +} + +#[test] +fn cast_u32() { + let v_f64 = [1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); + + // u32 -> f32 + let results: Vec = run_cast(&v_u32, "cast_u32_f32"); + assert_eq!(results, v_f32); + + // u32 -> f16 + let results: Vec = run_cast(&v_u32, "cast_u32_f16"); + assert_eq!(results, v_f16); + + // u32 -> bf16 + let results: Vec = run_cast(&v_u32, "cast_u32_bf16"); + assert_eq!(results, v_bf16); + + // u32 -> u8 + let results: Vec = run_cast(&v_u32, "cast_u32_u8"); + assert_eq!(results, v_u8); + + // u32 -> i64 + let results: Vec = run_cast(&v_u32, "cast_u32_i64"); + assert_eq!(results, v_i64); +} + +#[test] +fn cast_u8() { + let v_f64 = [1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); + + // u8 -> f32 + let results: Vec = run_cast(&v_u8, "cast_u8_f32"); + assert_eq!(results, v_f32); + + // u8 -> f16 + let results: Vec = run_cast(&v_u8, "cast_u8_f16"); + assert_eq!(results, v_f16); + + // u8 -> bf16 + let results: Vec = run_cast(&v_u8, "cast_u8_bf16"); + assert_eq!(results, v_bf16); + + // u8 -> u32 + let results: Vec = run_cast(&v_u8, "cast_u8_u32"); + assert_eq!(results, v_u32); + + // u8 -> i64 + let results: Vec = run_cast(&v_u8, "cast_u8_i64"); + assert_eq!(results, v_i64); +} + +#[test] +fn cast_i64() { + let v_f64 = [1.0f64, 2.0, 3.0]; + let v_f32: Vec = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec = v_f64.iter().map(|&v| v as i64).collect(); + + // i64 -> f32 + let results: Vec = run_cast(&v_i64, "cast_i64_f32"); + assert_eq!(results, v_f32); + + // i64 -> f16 + let results: Vec = run_cast(&v_i64, "cast_i64_f16"); + assert_eq!(results, v_f16); + + // i64 -> bf16 + let results: Vec = run_cast(&v_i64, "cast_i64_bf16"); + assert_eq!(results, v_bf16); + + // i64 -> u32 + let results: Vec = run_cast(&v_i64, "cast_i64_u32"); + assert_eq!(results, v_u32); + + // i64 -> u8 + let results: Vec = 
run_cast(&v_i64, "cast_i64_u8"); + assert_eq!(results, v_u8); +} + +fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + let size = v.len(); + + call_affine( + &device, + command_buffer, + &kernels, + "affine_f32", + size, + BufferOffset::zero_offset(&input), + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, v.len()) +} + +fn run_affine_strided( + v: &[T], + shape: &[usize], + strides: &[usize], + mul: f64, + add: f64, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + call_affine_strided( + &device, + command_buffer, + &kernels, + "affine_f32_strided", + shape, + BufferOffset::zero_offset(&input), + strides, + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let len: usize = shape.iter().product(); + read_to_vec(&output, len) +} + +#[test] +fn affine() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + + let input = [1.0f32; 40_000]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6; 40_000]); +} + +#[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} + +#[test] +fn index_select() { + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let stride = [2, 1]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); + assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [2, 5]; + let stride = [1, 2]; + let ids = [0u32, 1, 0]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); +} + +#[test] +fn index_select_strided() { + let embedding = (0..16).map(|x| x as f32).collect::>(); + let shape = [2, 2]; + let stride = [2, 4]; + let ids = [0u32]; + let dim = 0; + let result = run_index_select_strided(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); + assert_eq!(result, vec![0.0, 4.0]); +} + +#[test] +fn index_select_f16() { + let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .into_iter() + .map(f16::from_f32) + .collect(); + let shape = [5, 2]; + let stride = [2, 1]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f16"); + assert_eq!( + approx_f16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn 
index_select_is_u32_bf16() { + let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let stride = [2, 1]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_is_u8_bf16() { + let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let stride = [2, 1]; + let ids = [0u8, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u8_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_dim1() { + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let stride = [2, 1]; + let ids = [0u32, 1, 0]; + let dim = 1; + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); + assert_eq!( + result, + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] + ); +} + +fn run_index_select( + embeddings: &[T], + shape: &[usize], + stride: &[usize], + ids: &[I], + dim: usize, + name: &'static str, +) -> Vec { + let device = Device::system_default().expect("no device found"); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let kernels = Kernels::new(); + call_index_select( + &device, + command_buffer, + &kernels, + name, + shape, + ids.len(), + dim, + true, + shape, + stride, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), + &dst_buffer, + ) + .unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&dst_buffer, dst_el) +} + +fn run_index_select_strided( + embeddings: &[T], + shape: &[usize], + stride: &[usize], + ids: &[I], + dim: usize, + name: &'static str, +) -> Vec { + let device = Device::system_default().expect("no device found"); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let kernels = Kernels::new(); + call_index_select( + &device, + command_buffer, + &kernels, + name, + shape, + ids.len(), + dim, + false, + shape, + stride, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), + &dst_buffer, + ) + .unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&dst_buffer, dst_el) +} + +#[test] +fn cos_f16() { + let v: Vec = [1.0f32, 2.0, 3.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let results = run(&v, unary::contiguous::cos::HALF); + let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); + assert_eq!(approx_f16(results, 
2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); +} + +fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + + let options = MTLResourceOptions::StorageModeManaged; + let output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); + let dims = vec![v.len()]; + let strides = vec![1]; + call_reduce_strided( + &device, + command_buffer, + &kernels, + name, + &dims, + &strides, + out_length, + BufferOffset::zero_offset(&input), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, out_length) +} + +fn run_softmax(v: &[T], last_dim: usize, name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + call_last_softmax( + &device, + command_buffer, + &kernels, + name, + v.len(), + last_dim, + &input, + 0, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, v.len()) +} + +#[test] +fn reduce_sum() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 1; + + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); + assert_eq!(approx(results, 4), vec![21.0]); +} + +#[test] +fn reduce_sum2() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 2; + + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); + assert_eq!(approx(results, 4), vec![6.0, 15.0]); +} + +#[test] +fn softmax() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_f32"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let last_dim = 4096; + let n = 200; + let mut v = vec![0.0; n * last_dim]; + for i in 0..n { + v[i * last_dim] = 20.0; + } + let results = run_softmax(&v, last_dim, "softmax_f32"); + let results = approx(results, 4); + assert_eq!( + results.iter().map(|&s| s.round() as usize).sum::(), + n + ); + assert_eq!(results[0], 1.0); + assert_eq!(results[1], 0.0); + assert_eq!(results[last_dim], 1.0); + assert_eq!(results[2 * last_dim], 1.0); + + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_f32"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 3; + let results = run_softmax(&v, last_dim, "softmax_f32"); + assert_eq!( + approx(results, 4), + vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] + ); + + let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_f16"); + assert_eq!( + approx_f16(results, 4), + vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] + ); + + let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_bf16"); + assert_eq!( + approx_bf16(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 
0.6328] + ); +} + +#[allow(clippy::too_many_arguments)] +fn run_where_cond( + shape: &[usize], + cond: &[I], + (cond_stride, cond_offset): (Vec, usize), + left_true: &[T], + (left_stride, left_offset): (Vec, usize), + right_false: &[T], + (_right_stride, _right_offset): (Vec, usize), + name: &'static str, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let length = cond.len(); + let cond = device.new_buffer_with_data( + cond.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(cond) as u64, + options, + ); + let left = device.new_buffer_with_data( + left_true.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + let right = device.new_buffer_with_data( + right_false.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + + let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + let cond = BufferOffset { + buffer: &cond, + offset_in_bytes: cond_offset, + }; + let left = BufferOffset { + buffer: &left, + offset_in_bytes: left_offset, + }; + let right = BufferOffset { + buffer: &right, + offset_in_bytes: cond_offset, + }; + call_where_cond_strided( + &device, + command_buffer, + &kernels, + name, + shape, + cond, + &cond_stride, + left, + &left_stride, + right, + &cond_stride, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn where_cond() { + let shape = vec![6]; + let cond = vec![0u8, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u8_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); +} +#[test] +fn where_cond_u32_f32() { + let shape = vec![6]; + let cond = vec![0u32, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u32_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); +} + +#[allow(clippy::too_many_arguments)] +fn run_gemm( + name: &'static str, + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: &[usize], + lhs_offset: usize, + rhs: &[T], + rhs_stride: &[usize], + rhs_offset: usize, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + 
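// Dispatch the requested gemm kernel (e.g. "sgemm") with alpha = 1 and beta = 0, wait for the command buffer to complete, then read the output back to the host. +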
call_gemm( + &device, + command_buffer, + &kernels, + name, + (b, m, n, k), + lhs_stride, + lhs_offset, + &lhs, + rhs_stride, + rhs_offset, + &rhs, + &output, + 1., + 0., + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn gemm() { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + &lhs_stride, + 0, + &rhs, + &rhs_stride, + 0, + ); + assert_eq!( + approx(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + &lhs_stride, + 0, + &rhs, + &rhs_stride, + 0, + ); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); + + // OFFSET + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 + let results = run_gemm( + "sgemm", + (1, m, n, k), + &lhs, + &lhs_stride, + 0, + &rhs, + &rhs_stride, + 12 * 4, + ); + assert_eq!( + approx(results, 4), + vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] + ); + + // bgemm sanity test + if false { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_gemm( + "bgemm", + (b, m, n, k), + &lhs, + &lhs_stride, + 0, + &rhs, + &rhs_stride, + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + } + + // hgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let results = run_gemm( + "hgemm", + (b, m, n, k), + &lhs, + &lhs_stride, + 0, + &rhs, + &rhs_stride, + 0, + ); + assert_eq!( + approx_f16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); +} + +#[allow(clippy::too_many_arguments)] +fn run_mlx_gemm( + dtype: GemmDType, + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: &[usize], + lhs_offset: usize, + rhs: &[T], + rhs_stride: &[usize], + rhs_offset: usize, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + 
std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + call_mlx_gemm( + &device, + command_buffer, + &kernels, + dtype, + (b, m, n, k), + lhs_stride, + lhs_offset, + &lhs, + rhs_stride, + rhs_offset, + &rhs, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) { + use rand::SeedableRng; + use rand_distr::Distribution; + + let mut rng = rand::rngs::StdRng::seed_from_u64(42424242); + let normal = rand_distr::Normal::new(0.0, 1.0).unwrap(); + + let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect(); + let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect(); + let v1: Vec = run_mlx_gemm( + dtype, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[k * n, n, 1], + 0, + ); + let v2: Vec = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[k * n, n, 1], + 0, + ); + for (a, b) in v1.iter().zip(v2.iter()) { + let diff = (a - b).abs(); + assert_eq!((diff * 1e4).round(), 0.) + } +} + +#[test] +fn mlx_vs_mfa() { + mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32); + mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32); + mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32); + mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32); + mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32); +} + +#[test] +fn mlx_gemm() { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_mlx_gemm( + GemmDType::F32, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + let (b, m, n, k) = (2, 2, 4, 3); + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_mlx_gemm( + GemmDType::F32, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); + + // OFFSET + let (b, m, n, k) = (2, 2, 4, 3); + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 + let results = run_mlx_gemm( + GemmDType::F32, + (1, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 12 * 4, + ); + assert_eq!( + approx(results, 4), + vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] + ); + + // bgemm sanity test + { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_mlx_gemm( + GemmDType::BF16, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + } + + { + // hgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs: Vec = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let results 
= run_mlx_gemm( + GemmDType::F16, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx_f16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + } +} + +fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let options = MTLResourceOptions::StorageModeManaged; + let output = device.new_buffer((length * core::mem::size_of::()) as NSUInteger, options); + + let seed = device.new_buffer_with_data( + &seed as *const u32 as *const core::ffi::c_void, + std::mem::size_of::() as NSUInteger, + options, + ); + + if name.starts_with("rand_uniform") { + call_random_uniform( + &device, + command_buffer, + &kernels, + name, + a, + b, + length, + &seed, + &output, + ) + .unwrap(); + } else { + call_random_normal( + &device, + command_buffer, + &kernels, + name, + a, + b, + length, + &seed, + &output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn random() { + fn calc_mean(data: &[f32]) -> f32 { + let sum = data.iter().sum::(); + let count = data.len(); + assert!(count > 0); + sum / count as f32 + } + + fn calc_stddev(data: &[f32]) -> f32 { + let mean = calc_mean(data); + let count = data.len(); + assert!(count > 0); + + let variance = data + .iter() + .map(|value| { + let diff = mean - *value; + diff * diff + }) + .sum::() + / count as f32; + + variance.sqrt() + } + + let shape = [1024, 10]; + + let length = shape.iter().product::(); + let seed = 299792458; + + let min = -30.0; + let max = 30.0; + let mean = 100.0; + let stddev = 50.0; + + macro_rules! 
validate_random { + ($type:ty) => { + let results: Vec = run_random::<$type>( + concat!("rand_uniform_", stringify!($type)), + seed, + length, + min, + max, + ) + .into_iter() + .map(f32::from) + .collect(); + results.iter().for_each(|v| { + assert!(*v >= min && *v <= max); + }); + assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0); + + let results: Vec = run_random::<$type>( + concat!("rand_normal_", stringify!($type)), + seed, + length, + mean, + stddev, + ) + .into_iter() + .map(f32::from) + .collect(); + assert!((calc_mean(&results) - mean).abs() < mean / 10.0); + assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0); + }; + } + + validate_random!(f32); + validate_random!(f16); + validate_random!(bf16); +} + +fn run_scatter_add( + input: &[T], + ids: &[I], + shape: &[usize], + dim: usize, + name: &'static str, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input_buffer = new_buffer(&device, input); + let ids_buffer = new_buffer(&device, ids); + let output = device.new_buffer(std::mem::size_of_val(input) as u64, options); + call_scatter_add( + &device, + command_buffer, + &kernels, + name, + shape, + shape, + dim, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&ids_buffer), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, input.len()) +} + +#[test] +fn scatter_add() { + let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3]; + let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3]; + let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3]; + + let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0]; + let input_f16 = input_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let input_bf16 = input_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + + let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0]; + let output_dim1_f16 = output_dim1_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let output_dim1_bf16 = output_dim1_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + + let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0]; + let output_dim2_f16 = output_dim2_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + let output_dim2_bf16 = output_dim2_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + + for (shape, output_f32, output_f16, output_bf16) in [ + (vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16), + ( + vec![4, 2], + output_dim2_f32, + output_dim2_f16, + output_dim2_bf16, + ), + ] { + for results in [ + run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"), + run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"), + run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"), + ] { + assert_eq!(results, output_f32); + } + for results in [ + run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"), + run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"), + run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"), + ] { + assert_eq!(results, output_f16); + } + for results in [ + run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"), + run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"), + run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"), + ] { + assert_eq!(results, output_bf16); + } + } +} + +fn run_index_add( + left: &[T], + 
right: &[T], + indices: &[I], + shape: &[usize], + dim: usize, + name: &'static str, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input_buffer = new_buffer(&device, right); + let output = new_buffer(&device, left); + let indices_buffer = new_buffer(&device, indices); + call_index_add( + &device, + command_buffer, + &kernels, + name, + shape, + shape, + shape, + dim, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&indices_buffer), + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, left.len()) +} + +#[test] +fn index_add() { + let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0]; + let indices = vec![0u32, 1, 0, 1, 0, 1]; + let shape = vec![6]; + + // u32, f32 + { + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u32, f16 + { + let left = left.iter().map(|v| f16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u32, bf16 + { + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, f32 + { + let indices = indices.iter().map(|v| *v as u8).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, f16 + { + let indices = indices.iter().map(|v| *v as u8).collect::>(); + let left = left.iter().map(|v| f16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, bf16 + { + let indices = indices.iter().map(|v| *v as u8).collect::>(); + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, f32 + { + let indices = indices.iter().map(|v| *v as i64).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, f16 + { + let indices = indices.iter().map(|v| *v as i64).collect::>(); + let left = left.iter().map(|v| f16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, bf16 + { + let indices = indices.iter().map(|v| *v as i64).collect::>(); + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let results = 
run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } +} + +fn run_pool2d( + v: &[T], + (w_k, h_k): (usize, usize), + (w_stride, h_stride): (usize, usize), + shape: &[usize], + strides: &[usize], + name: &'static str, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let out_w = (shape[2] - w_k) / w_stride + 1; + let out_h = (shape[3] - h_k) / h_stride + 1; + let dst_el = out_w * out_h * shape[0] * shape[1]; + let input = new_buffer(&device, v); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + call_pool2d( + &device, + command_buffer, + &kernels, + name, + shape, + strides, + out_w, + out_h, + w_k, + h_k, + w_stride, + h_stride, + &input, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn max_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f32", + ); + let expected = vec![5.0, 7.0, 13.0, 15.0]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| half::f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_f16", + ); + let expected = [5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| half::bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let 
results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_bf16", + ); + let expected = [5.0, 7.0, 13.0, 15.0] + .iter() + .map(|v| half::bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u8", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} + +#[test] +fn max_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 6, 7, 9, 10, 11, 13, 14, 15]; + assert_eq!(results, expected); + + // kernel 2 stride 2 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 2; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "max_pool2d_u32", + ); + let expected = vec![5, 7, 13, 15]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_f32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as f32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f32", + ); + let expected = vec![ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_f16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| f16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_f16", + ); + let expected = [ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_bf16() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| bf16::from_f32(v as f32)).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_bf16", + ); + let expected = [ + 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, + ] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u8() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u8).collect(); + let shape = vec![1, 1, 4, 4]; + let 
strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u8", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} + +#[test] +fn avg_pool2d_u32() { + // kernel 2 stride 1 + let v: Vec = (0..16).map(|v| v as u32).collect(); + let shape = vec![1, 1, 4, 4]; + let strides = vec![16, 16, 4, 1]; + let kernel = 2; + let stride = 1; + let results = run_pool2d( + &v, + (kernel, kernel), + (stride, stride), + &shape, + &strides, + "avg_pool2d_u32", + ); + let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12]; + assert_eq!(results, expected); +} + +#[allow(clippy::too_many_arguments)] +fn run_conv_transpose1d( + input: &[T], + input_shape: &[usize], + input_stride: &[usize], + kernel: &[T], + kernel_shape: &[usize], + kernel_stride: &[usize], + dilation: usize, + stride: usize, + padding: usize, + out_padding: usize, + name: &'static str, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let c_out = kernel_shape[1]; + let k_size = kernel_shape[2]; + let b_size = input_shape[0]; + let l_in = input_shape[2]; + let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1; + let dst_el = c_out * l_out * b_size; + + let input = new_buffer(&device, input); + let kernel = new_buffer(&device, kernel); + let output = new_buffer(&device, &vec![0.0f32; dst_el]); + let kernels = Kernels::new(); + + call_conv_transpose1d( + &device, + command_buffer, + &kernels, + name, + dilation, + stride, + padding, + out_padding, + c_out, + l_out, + b_size, + input_shape, + input_stride, + kernel_shape, + kernel_stride, + &input, + 0, + &kernel, + 0, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, dst_el) +} + +#[test] +fn conv_transpose1d_f32() { + let input = vec![1.0f32, 2.0, 3.0, 4.0]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel = vec![1.0f32, 2.0, 3.0, 4.0]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_f32", + ); + + let expected = vec![1., 4., 10., 20., 25., 24., 16.]; + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_f16() { + let input: Vec = [1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = [1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_f16", + ); + + let expected = [1., 4., 10., 20., 25., 24., 16.] 
+ .iter() + .map(|v| f16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_bf16() { + let input: Vec = [1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect(); + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = [1.0, 2.0, 3.0, 4.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect(); + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_bf16", + ); + + let expected = [1., 4., 10., 20., 25., 24., 16.] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::>(); + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_u8() { + let input: Vec = vec![1, 2, 3, 4]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1, 2, 3, 4]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_u8", + ); + + let expected = vec![1, 4, 10, 20, 25, 24, 16]; + assert_eq!(results, expected); +} + +#[test] +fn conv_transpose1d_u32() { + let input: Vec = vec![1, 2, 3, 4]; + let input_shape = &[1, 1, 4]; + let input_stride = &[4, 4, 1]; + + let kernel: Vec = vec![1, 2, 3, 4]; + let kernel_shape = &[1, 1, 4]; + let kernel_stride = &[4, 4, 1]; + + let results = run_conv_transpose1d( + &input, + input_shape, + input_stride, + &kernel, + kernel_shape, + kernel_stride, + 1, + 1, + 0, + 0, + "conv_transpose1d_u32", + ); + + let expected = vec![1, 4, 10, 20, 25, 24, 16]; + assert_eq!(results, expected); +} + +#[test] +fn const_fill() { + fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec { + let dev = device(); + let kernels = Kernels::new(); + let command_queue = dev.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let buffer = dev.new_buffer( + (len * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModePrivate, + ); + call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec::(&buffer, len) + } + fn test T>(name: &'static str, f: F) { + let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); + let value = rand::thread_rng().gen_range(1. 
..19.); + let v = constant_fill::(name, len, value); + assert_eq!(v, vec![f(value); len]) + } + test::("fill_u8", |v| v as u8); + test::("fill_u32", |v| v as u32); + test::("fill_i64", |v| v as i64); + test::("fill_f16", f16::from_f32); + test::("fill_bf16", bf16::from_f32); + test::("fill_f32", |v| v); +} diff --git a/mcandle/mcandle-metal-kernels/src/unary.metal b/mcandle/mcandle-metal-kernels/src/unary.metal new file mode 100644 index 000000000..ab4342ec3 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/unary.metal @@ -0,0 +1,208 @@ +#include +#include +# +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T recip(T in){ return T(1.0 / in); } +template METAL_FUNC T neg(T in){ return -in; } + +template METAL_FUNC T erf(T in){ + float x = (float) in; + // constants + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + // Save the sign of x + int sign = 1; + if (x < 0) + sign = -1; + x = fabs(x); + + // A&S formula 7.1.26 + float t = 1.0/(1.0 + p*x); + float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); + + return T(sign*y); +} +template METAL_FUNC T id(T in) { return in; } +template METAL_FUNC T gelu_erf(T x) { + return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); +} +template METAL_FUNC T gelu(T x) { + if (x > 5) { + return x; + } + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast(0.044715) * x_cube; + T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); + return static_cast(0.5) * x * (static_cast(1.0) + T(precise::tanh(beta))); +} +template METAL_FUNC T relu(T in){ + if (in < 0) { + return 0; + } + return in; +} +template METAL_FUNC T silu(T in){ + return in / (static_cast(1) + exp(-in)); +} +template METAL_FUNC T sigmoid(T in) { + return recip(static_cast(1) + exp(-in)); +} + +#define TILE_SIZE 2 + +#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = TYPENAME(FN(float(input[tid]))); \ +} \ +kernel void FN_NAME##_##strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \ +} \ +kernel void FN_NAME##_##tiled( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + for (uint i = 0; i < TILE_SIZE; i++) { \ + const uint idx = tid * TILE_SIZE + i; \ + output[idx] = TYPENAME(FN(float(input[idx]))); \ + } \ +} + +#define UNARY_OP(NAME) \ +UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \ +UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); + +#define BFLOAT_UNARY_OP(NAME) \ +UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); + +#define 
COPY2D(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant int64_t &d1, \ + constant int64_t &d2, \ + constant int64_t &src_s, \ + constant int64_t &dst_s, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint2 idx [[thread_position_in_grid]] \ +) { \ + if (idx.x >= d1 || idx.y >= d2) return; \ + int64_t src_idx = idx.x * src_s + idx.y; \ + int64_t dst_idx = idx.x * dst_s + idx.y; \ + output[dst_idx] = input[src_idx]; \ +} + +COPY2D(copy2d_f32, float) +COPY2D(copy2d_f16, half) +COPY2D(copy2d_u8, uint8_t) +COPY2D(copy2d_u32, uint32_t) + +UNARY_OP(cos) +UNARY_OP(sin) +UNARY_OP(sqr) +UNARY_OP(sqrt) +UNARY_OP(neg) +UNARY_OP(exp) +UNARY_OP(log) +UNARY_OP(gelu) +UNARY_OP(silu) +UNARY_OP(abs) +UNARY_OP(ceil) +UNARY_OP(floor) +UNARY_OP(round) +UNARY_OP(gelu_erf) +UNARY_OP(erf) +UNARY_OP(recip) +UNARY_OP(relu) +UNARY_OP(sign) +UNARY_OP(sigmoid) +UNARY(id, float, copy_f32, copy_f32_strided) +UNARY(id, half, copy_f16, copy_f16_strided) +UNARY(id, uint8_t, copy_u8, copy_u8_strided) +UNARY(id, uint32_t, copy_u32, copy_u32_strided) + +// tanh may create NaN on large values, e.g. 45 rather than outputing 1. +// This has been an issue for the encodec example. +UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided); +UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided); + +#if __METAL_VERSION__ >= 220 +UNARY(id, int64_t, copy_i64, copy_i64_strided) +COPY2D(copy2d_i64, int64_t) +#endif + +UNARY(id, int32_t, copy_i32, copy_i32_strided) +COPY2D(copy2d_i32, int32_t) + +UNARY(id, int16_t, copy_i16, copy_i16_strided) +COPY2D(copy2d_i16, int16_t) + +#if defined(__HAVE_BFLOAT__) +BFLOAT_UNARY_OP(cos) +BFLOAT_UNARY_OP(sin) +BFLOAT_UNARY_OP(sqr) +BFLOAT_UNARY_OP(sqrt) +BFLOAT_UNARY_OP(neg) +BFLOAT_UNARY_OP(exp) +BFLOAT_UNARY_OP(log) +BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(silu) +BFLOAT_UNARY_OP(abs) +BFLOAT_UNARY_OP(ceil) +BFLOAT_UNARY_OP(floor) +BFLOAT_UNARY_OP(round) +BFLOAT_UNARY_OP(gelu_erf) +BFLOAT_UNARY_OP(erf) +BFLOAT_UNARY_OP(recip) +BFLOAT_UNARY_OP(relu) +BFLOAT_UNARY_OP(sign) +BFLOAT_UNARY_OP(sigmoid) + +UNARY(id, bfloat, copy_bf16, copy_bf16_strided) + +UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided); + +COPY2D(copy2d_bf16, bfloat) +#endif diff --git a/mcandle/mcandle-metal-kernels/src/utils.rs b/mcandle/mcandle-metal-kernels/src/utils.rs new file mode 100644 index 000000000..0092ecfa5 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/src/utils.rs @@ -0,0 +1,220 @@ +use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize}; +use std::ffi::c_void; + +/// Most kernels apply similarly across the tensors +/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the +/// actual total buffer length). +/// Then kernels can just do their op on their single point in the buffer. 
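+/// Illustrative example of the resulting split: a 4096-element buffer on a pipeline allowing
+/// 1024 threads per threadgroup yields 4 threadgroups of 1024 threads, while a 10-element
+/// buffer yields a single threadgroup of 10 threads (numbers here are for illustration only).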
+pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { + let size = length as u64; + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); + let count = (size + width - 1) / width; + let thread_group_count = MTLSize { + width: count, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + (thread_group_count, thread_group_size) +} + +// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 +pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { + let mut pows0 = 0u64; + let mut pows1 = 0u64; + let mut pows2 = 0u64; + let mut sum = 0u64; + loop { + let presum = sum; + // Check all the pows + if dim0 >= (1 << (pows0 + 1)) { + pows0 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim1 >= (1 << (pows1 + 1)) { + pows1 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim2 >= (1 << (pows2 + 1)) { + pows2 += 1; + sum += 1; + } + if sum == presum || sum == 10 { + break; + } + } + MTLSize { + width: 1 << pows0, + height: 1 << pows1, + depth: 1 << pows2, + } +} + +pub fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { +
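+    // Forward to the EncoderParam impl for the concrete type of `data`
+    // (set_bytes for scalars and slices, set_buffer for Metal buffers / buffer offsets).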
<P as EncoderParam>
::set_param(encoder, position, data) +} + +/// Helper functions to create the various objects on the compute command encoder +/// on a single line. +/// Prevents getting wrong some arguments number and mixing length and size in bytes. +pub trait EncoderParam { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); +} +macro_rules! primitive { + ($type:ty) => { + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } + }; +} +primitive!(bool); +primitive!(usize); +primitive!(i32); +primitive!(i64); +primitive!(u32); +primitive!(u64); +primitive!(f32); + +pub struct BufferOffset<'a> { + pub buffer: &'a Buffer, + pub offset_in_bytes: usize, +} + +impl<'a> BufferOffset<'a> { + pub fn zero_offset(buffer: &'a Buffer) -> Self { + Self { + buffer, + offset_in_bytes: 0, + } + } +} + +impl EncoderParam for &[T] { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, + ); + } +} + +impl EncoderParam for &Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} + +impl EncoderParam for (&Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +impl<'a> EncoderParam for &BufferOffset<'a> { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64); + } +} + +impl EncoderParam for &mut Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} + +impl EncoderParam for (&mut Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +#[macro_export] +macro_rules! 
set_params { + ($encoder:ident, ($($param:expr),+)) => ( + let mut _index = 0; + $( + $crate::utils::set_param($encoder, _index, $param); + _index += 1; + )* + ); +} + +pub trait EncoderProvider { + type Encoder<'a>: AsRef + where + Self: 'a; + fn encoder(&self) -> Self::Encoder<'_>; +} + +pub struct WrappedEncoder<'a> { + inner: &'a ComputeCommandEncoderRef, + end_encoding_on_drop: bool, +} + +impl<'a> Drop for WrappedEncoder<'a> { + fn drop(&mut self) { + if self.end_encoding_on_drop { + self.inner.end_encoding() + } + } +} + +impl<'a> AsRef for WrappedEncoder<'a> { + fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { + self.inner + } +} + +impl EncoderProvider for &metal::CommandBuffer { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder(&self) -> Self::Encoder<'_> { + WrappedEncoder { + inner: self.new_compute_command_encoder(), + end_encoding_on_drop: true, + } + } +} + +impl EncoderProvider for &metal::CommandBufferRef { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder(&self) -> Self::Encoder<'_> { + WrappedEncoder { + inner: self.new_compute_command_encoder(), + end_encoding_on_drop: true, + } + } +} + +impl EncoderProvider for &ComputeCommandEncoderRef { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder(&self) -> Self::Encoder<'_> { + WrappedEncoder { + inner: self, + end_encoding_on_drop: false, + } + } +} diff --git a/mcandle/mcandle-metal-kernels/tmp/affine.rs b/mcandle/mcandle-metal-kernels/tmp/affine.rs new file mode 100644 index 000000000..b693f8e9f --- /dev/null +++ b/mcandle/mcandle-metal-kernels/tmp/affine.rs @@ -0,0 +1,76 @@ +use mcandle_metal_kernels::{call_affine, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_affine_bench(&device, &kernels, &f32_1k); + run_affine_bench(&device, &kernels, &f32_10k); + run_affine_bench(&device, &kernels, &f32_100k); +} + +fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 10000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + let mul: f32 = 1.2345; + let add: f32 = 2.3456; + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_affine( + &device, + command_buffer, + &kernels, + "affine_float", + v.len(), + &input, + &mut output, + mul, + add, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + "affine", + v.len(), + iterations, + total_time, + total_time / iterations + ); +} diff --git 
a/mcandle/mcandle-metal-kernels/tmp/binary.rs b/mcandle/mcandle-metal-kernels/tmp/binary.rs new file mode 100644 index 000000000..6c207ded7 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/tmp/binary.rs @@ -0,0 +1,182 @@ +use mcandle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels}; +use half::{bf16, f16}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::>(); + let f16_1k = f16_map(&f32_1k); + let f16_10k = f16_map(&f32_10k); + let f16_100k = f16_map(&f32_100k); + + let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let bf16_1k = bf16_map(&f32_1k); + let bf16_10k = bf16_map(&f32_10k); + let bf16_100k = bf16_map(&f32_100k); + + let f32_ckernels = [ + binary::contiguous::add::FLOAT, + binary::contiguous::sub::FLOAT, + binary::contiguous::mul::FLOAT, + binary::contiguous::div::FLOAT, + ]; + let f32_skernels = [ + binary::strided::add::FLOAT, + binary::strided::sub::FLOAT, + binary::strided::mul::FLOAT, + binary::strided::div::FLOAT, + ]; + let f16_ckernels = [ + binary::contiguous::add::HALF, + binary::contiguous::sub::HALF, + binary::contiguous::mul::HALF, + binary::contiguous::div::HALF, + ]; + let f16_skernels = [ + binary::strided::add::HALF, + binary::strided::sub::HALF, + binary::strided::mul::HALF, + binary::strided::div::HALF, + ]; + let bf16_ckernels = [ + binary::contiguous::add::BFLOAT, + binary::contiguous::sub::BFLOAT, + binary::contiguous::mul::BFLOAT, + binary::contiguous::div::BFLOAT, + ]; + let bf16_skernels = [ + binary::strided::add::BFLOAT, + binary::strided::sub::BFLOAT, + binary::strided::mul::BFLOAT, + binary::strided::div::BFLOAT, + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels); + run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels); + run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels); + + // f16 + run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels); + run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels); + run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels); + + // bf16 + run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels); + run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels); + run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels); +} + +fn run_binary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: [binary::contiguous::Kernel; 4], + strided: [binary::strided::Kernel; 4], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for 
kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_binary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + for kernel_name in strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_binary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &input, + &strides, + offset, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/mcandle/mcandle-metal-kernels/tmp/cast.rs b/mcandle/mcandle-metal-kernels/tmp/cast.rs new file mode 100644 index 000000000..2e9e26371 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/tmp/cast.rs @@ -0,0 +1,84 @@ +use mcandle_metal_kernels::{call_cast_contiguous, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let contiguous_kernels = ["cast_u32_f32"]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels); + run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels); + run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels); +} + +fn run_cast_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: &[&'static str], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_cast_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: 
<11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided? +} diff --git a/mcandle/mcandle-metal-kernels/tmp/unary.rs b/mcandle/mcandle-metal-kernels/tmp/unary.rs new file mode 100644 index 000000000..21c902a62 --- /dev/null +++ b/mcandle/mcandle-metal-kernels/tmp/unary.rs @@ -0,0 +1,197 @@ +use mcandle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels}; +use half::{bf16, f16}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::>(); + let f16_1k = f16_map(&f32_1k); + let f16_10k = f16_map(&f32_10k); + let f16_100k = f16_map(&f32_100k); + + let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let bf16_1k = bf16_map(&f32_1k); + let bf16_10k = bf16_map(&f32_10k); + let bf16_100k = bf16_map(&f32_100k); + + let f32_ckernels = [ + unary::contiguous::sin::FLOAT, + unary::contiguous::cos::FLOAT, + unary::contiguous::exp::FLOAT, + unary::contiguous::sqr::FLOAT, + unary::contiguous::sqrt::FLOAT, + unary::contiguous::neg::FLOAT, + unary::contiguous::copy::FLOAT, + ]; + let f32_skernels = [ + unary::strided::sin::FLOAT, + unary::strided::cos::FLOAT, + unary::strided::exp::FLOAT, + unary::strided::sqr::FLOAT, + unary::strided::sqrt::FLOAT, + unary::strided::neg::FLOAT, + unary::strided::copy::FLOAT, + ]; + let f16_ckernels = [ + unary::contiguous::sin::HALF, + unary::contiguous::cos::HALF, + unary::contiguous::exp::HALF, + unary::contiguous::sqr::HALF, + unary::contiguous::sqrt::HALF, + unary::contiguous::neg::HALF, + unary::contiguous::copy::HALF, + ]; + let f16_skernels = [ + unary::strided::sin::HALF, + unary::strided::cos::HALF, + unary::strided::exp::HALF, + unary::strided::sqr::HALF, + unary::strided::sqrt::HALF, + unary::strided::neg::HALF, + unary::strided::copy::HALF, + ]; + let bf16_ckernels = [ + unary::contiguous::sin::BFLOAT, + unary::contiguous::cos::BFLOAT, + unary::contiguous::exp::BFLOAT, + unary::contiguous::sqr::BFLOAT, + unary::contiguous::sqrt::BFLOAT, + unary::contiguous::neg::BFLOAT, + unary::contiguous::copy::BFLOAT, + ]; + let bf16_skernels = [ + unary::strided::sin::BFLOAT, + unary::strided::cos::BFLOAT, + unary::strided::exp::BFLOAT, + unary::strided::sqr::BFLOAT, + unary::strided::sqrt::BFLOAT, + unary::strided::neg::BFLOAT, + unary::strided::copy::BFLOAT, + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels); + run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels); + run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels); + + // f16 + run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels); + run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels); + run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels); + + // bf16 + run_unary_bench(&device, &kernels, &bf16_1k, 
bf16_ckernels, bf16_skernels); + run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels); + run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels); +} + +fn run_unary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: [unary::contiguous::Kernel; 7], + strided: [unary::strided::Kernel; 7], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 10000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_unary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.0, + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + for kernel_name in &strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_unary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &mut output, + 0, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.0, + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/mcandle/mcandle-nn/Cargo.toml b/mcandle/mcandle-nn/Cargo.toml new file mode 100644 index 000000000..f78f0ca85 --- /dev/null +++ b/mcandle/mcandle-nn/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "mcandle-nn" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +accelerate-src = { workspace = true, optional = true } +mcandle-core = { workspace = true } +half = { workspace = true } +thiserror = { workspace = true } +intel-mkl-src = { workspace = true, optional = true } +num-traits = { workspace = true } +rayon = { workspace = true } +safetensors = { workspace = true } +serde = { workspace = true } +metal = { workspace = true, optional = true } +mcandle-metal-kernels = { workspace = true, optional = true } +mcandle-flash-attn = { workspace = true, optional = true } + +[dev-dependencies] +anyhow = { workspace = true } +clap = { workspace = true } +rand = { workspace = true } +criterion = { workspace = true } + +[features] +default = [] +accelerate = ["dep:accelerate-src", "mcandle-core/accelerate"] +cuda = ["mcandle-core/cuda"] +mkl = ["dep:intel-mkl-src", "mcandle-core/mkl"] +metal = ["mcandle-core/metal", "dep:mcandle-metal-kernels", "dep:metal"] +flash-attn = ["cuda", "dep:mcandle-flash-attn"] + +[[bench]] +name = 
"bench_main" +harness = false \ No newline at end of file diff --git a/mcandle/mcandle-nn/README.md b/mcandle/mcandle-nn/README.md new file mode 100644 index 000000000..d013f9dd5 --- /dev/null +++ b/mcandle/mcandle-nn/README.md @@ -0,0 +1 @@ +# mcandle-nn diff --git a/mcandle/mcandle-nn/benches/bench_main.rs b/mcandle/mcandle-nn/benches/bench_main.rs new file mode 100644 index 000000000..727479b5c --- /dev/null +++ b/mcandle/mcandle-nn/benches/bench_main.rs @@ -0,0 +1,9 @@ +mod benchmarks; + +use criterion::criterion_main; +criterion_main!( + benchmarks::layer_norm::benches, + benchmarks::conv::benches, + benchmarks::attention::benches_fast, + benchmarks::attention::benches_naive +); diff --git a/mcandle/mcandle-nn/benches/benchmarks/attention.rs b/mcandle/mcandle-nn/benches/benchmarks/attention.rs new file mode 100644 index 000000000..a97a9c45f --- /dev/null +++ b/mcandle/mcandle-nn/benches/benchmarks/attention.rs @@ -0,0 +1,111 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use mcandle_core::{DType, Device, Tensor}; +use mcandle_nn::scaled_dot_product_attention; +use std::time::Instant; + +fn run_attention(q: &Tensor, k: &Tensor, v: &Tensor, m: &Tensor, s: f64) { + let att = (q + .contiguous() + .unwrap() + .matmul(&k.t().unwrap().contiguous().unwrap()) + .unwrap() + / s) + .unwrap(); + + let att = att.broadcast_add(m).unwrap(); + + let att = mcandle_nn::ops::softmax_last_dim(&att).unwrap(); + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous().unwrap()).unwrap(); +} + +fn run_bench_naive(c: &mut Criterion, device: &Device) { + let b = 4; + let seq = 1024; + let heads = 32; + let hd = 128; + + let dtype = DType::F32; + let q = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let k = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let v = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let m = Tensor::zeros((b, heads, seq, seq), dtype, device).unwrap(); + + let flops = b * seq * heads * hd; + + let mut group = c.benchmark_group(device.bench_name("attention_naive")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run_attention( + black_box(&q), + black_box(&k), + black_box(&v), + black_box(&m), + 0.3, + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark_naive(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench_naive(c, &device); + } +} + +fn run_bench_fast(c: &mut Criterion, device: &Device) { + let b = 4; + let seq = 1024; + let heads = 32; + let hd = 128; + + let dtype = DType::F32; + let q = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let k = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let v = Tensor::zeros((b, heads, seq, hd), dtype, device).unwrap(); + let m = Tensor::zeros((b, heads, seq, seq), dtype, device).unwrap(); + + let flops = b * seq * heads * hd; + + let mut group = c.benchmark_group(device.bench_name("attention_fast")); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + let _ = scaled_dot_product_attention( + black_box(&q), + black_box(&k), + black_box(&v), + 0.3, + Some(black_box(&m)), + false, 
+ seq, + ) + .unwrap(); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark_fast(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench_fast(c, &device); + } +} + +criterion_group!(benches_naive, criterion_benchmark_naive); +criterion_group!(benches_fast, criterion_benchmark_fast); diff --git a/mcandle/mcandle-nn/benches/benchmarks/conv.rs b/mcandle/mcandle-nn/benches/benchmarks/conv.rs new file mode 100644 index 000000000..5a1e221c6 --- /dev/null +++ b/mcandle/mcandle-nn/benches/benchmarks/conv.rs @@ -0,0 +1,54 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion}; +use mcandle_core::{DType, Device, Module, Tensor}; +use mcandle_nn::{Conv2d, Conv2dConfig}; +use std::time::Instant; + +const B: usize = 1; +const C: usize = 1; +const M: usize = 128; +const K: usize = 128; +const K_SIZE: usize = 3; + +fn run(input: Tensor, weight: Tensor, bias: Tensor, config: Conv2dConfig) { + Conv2d::new(weight, Some(bias), config) + .forward(&input) + .unwrap(); +} + +fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let weight = Tensor::ones((1, 1, K_SIZE, K_SIZE), dtype, device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + let bias = Tensor::zeros(K, dtype, device).unwrap(); + let input = Tensor::ones((B, C, M, K), dtype, device).unwrap(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(input.clone()), + black_box(weight.clone()), + black_box(bias.clone()), + Default::default(), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32"); + run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-nn/benches/benchmarks/layer_norm.rs b/mcandle/mcandle-nn/benches/benchmarks/layer_norm.rs new file mode 100644 index 000000000..5acd8fb38 --- /dev/null +++ b/mcandle/mcandle-nn/benches/benchmarks/layer_norm.rs @@ -0,0 +1,48 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use criterion::{black_box, criterion_group, Criterion}; +use mcandle_core::{DType, Device, Module, Tensor}; +use mcandle_nn::LayerNorm; +use std::time::Instant; + +fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) { + let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(input); +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; + +fn run_layer_norm_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let elements = B * M * K; + + let weight = Tensor::arange(0.0, elements as f32, device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + let bias = weight.ones_like().unwrap(); + let input = weight.ones_like().unwrap(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&input), black_box(&weight), black_box(&bias)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = 
BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_layer_norm_benchmark(c, &d, DType::F32, "layer_norm_f32"); + run_layer_norm_benchmark(c, &d, DType::BF16, "layer_norm_bf16"); + run_layer_norm_benchmark(c, &d, DType::F16, "layer_norm_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/mcandle/mcandle-nn/benches/benchmarks/mod.rs b/mcandle/mcandle-nn/benches/benchmarks/mod.rs new file mode 100644 index 000000000..6249c660a --- /dev/null +++ b/mcandle/mcandle-nn/benches/benchmarks/mod.rs @@ -0,0 +1,68 @@ +pub(crate) mod attention; +pub(crate) mod conv; +pub(crate) mod layer_norm; + +use mcandle_core::{Device, Result}; + +pub(crate) trait BenchDevice { + fn sync(&self) -> Result<()>; + + fn bench_name<S: Into<String>>(&self, name: S) -> String; +} + +impl BenchDevice for Device { + fn sync(&self) -> Result<()> { + match self { + Device::Cpu => Ok(()), + Device::Cuda(device) => { + #[cfg(feature = "cuda")] + { + use mcandle_core::cuda::WrapErr; + return Ok(device.synchronize().w()?); + } + #[cfg(not(feature = "cuda"))] + panic!("Cuda device without cuda feature enabled: {:?}", device) + } + Device::Metal(device) => { + #[cfg(feature = "metal")] + return Ok(device.wait_until_completed()?); + #[cfg(not(feature = "metal"))] + panic!("Metal device without metal feature enabled: {:?}", device) + } + } + } + + fn bench_name<S: Into<String>>(&self, name: S) -> String { + match self { + Device::Cpu => { + let cpu_type = if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + }; + format!("{}_{}", cpu_type, name.into()) + } + Device::Cuda(_) => format!("cuda_{}", name.into()), + Device::Metal(_) => format!("metal_{}", name.into()), + } + } +} + +struct BenchDeviceHandler { + devices: Vec<Device>, +} + +impl BenchDeviceHandler { + pub fn new() -> Result<Self> { + let mut devices = Vec::new(); + if cfg!(feature = "metal") { + devices.push(Device::new_metal(0)?); + } else if cfg!(feature = "cuda") { + devices.push(Device::new_cuda(0)?); + } + devices.push(Device::Cpu); + Ok(Self { devices }) + } +} diff --git a/mcandle/mcandle-nn/examples/basic_optimizer.rs b/mcandle/mcandle-nn/examples/basic_optimizer.rs new file mode 100644 index 000000000..6a09a4634 --- /dev/null +++ b/mcandle/mcandle-nn/examples/basic_optimizer.rs @@ -0,0 +1,39 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap}; + +fn gen_data() -> Result<(Tensor, Tensor)> { + // Generate some sample linear data. + let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + Ok((sample_xs, sample_ys)) +} + +fn main() -> Result<()> { + let (sample_xs, sample_ys) = gen_data()?; + + // Use backprop to run a linear regression between samples and get the coefficients back.
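+ // Since the samples were produced by w = [[3., 1.]] and b = -2. in `gen_data`, the AdamW loop
+ // below should drive the learned `linear` parameters towards those values and the summed
+ // squared-error loss towards zero.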
+ let varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu); + let model = linear(2, 1, vb.pp("linear"))?; + let params = ParamsAdamW { + lr: 0.1, + ..Default::default() + }; + let mut opt = AdamW::new(varmap.all_vars(), params)?; + for step in 0..10000 { + let ys = model.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + opt.backward_step(&loss)?; + println!("{step} {}", loss.to_vec0::()?); + } + Ok(()) +} diff --git a/mcandle/mcandle-nn/examples/cpu_benchmarks.rs b/mcandle/mcandle-nn/examples/cpu_benchmarks.rs new file mode 100644 index 000000000..f6bb14b0d --- /dev/null +++ b/mcandle/mcandle-nn/examples/cpu_benchmarks.rs @@ -0,0 +1,343 @@ +/// This example contains some simple benchmarks so that it's easy to run them in perf etc. +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, Subcommand}; +use mcandle_core::quantized::GgmlType; +use mcandle_core::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D}; + +const CHECK_CONV2D: bool = false; + +trait Benchmark { + type PreProcessData; + type RunResult; + + fn preprocess() -> Result; + fn run_one(_: &Self::PreProcessData) -> Result; + + const ITERS: usize; +} + +struct Im2Col { + h_k: usize, + w_k: usize, + stride: usize, + dilation: usize, + padding: usize, +} + +impl Im2Col { + fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { + let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; + let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; + (h_out, w_out) + } +} + +impl mcandle_core::CustomOp1 for Im2Col { + fn name(&self) -> &'static str { + "im2col" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + let &Self { + h_k, + w_k, + stride, + dilation, + padding, + } = self; + let (b, c, h, w) = layout.shape().dims4()?; + let (h_out, w_out) = self.hw_out(h, w); + let slice = storage.as_slice::()?; + let src = &slice[layout.start_offset()..]; + let mut dst = vec![0f32; b * h_out * w_out * c * h_k * w_k]; + let (src_s0, src_s1, src_s2, src_s3) = { + let s = layout.stride(); + (s[0], s[1], s[2], s[3]) + }; + // TODO: provide specialized kernels for the common use cases. 
+ // - h_k = w_k = 1 + // - padding = 0 + // - stride = 1 + // - dilation = 1 + for b_idx in 0..b { + let src_idx = b_idx * src_s0; + let dst_idx = b_idx * h_out * w_out * c * h_k * w_k; + for h_idx in 0..h_out { + let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k; + for w_idx in 0..w_out { + let dst_idx = dst_idx + w_idx * c * h_k * w_k; + for c_idx in 0..c { + let dst_idx = dst_idx + c_idx * h_k * w_k; + let src_idx = c_idx * src_s1 + src_idx; + for h_k_idx in 0..h_k { + let src_h = h_idx * stride + h_k_idx * dilation; + if padding != 0 && (src_h < padding || src_h >= h + padding) { + continue; + } + let src_h = src_h - padding; + let src_idx = src_idx + src_h * src_s2; + let dst_idx = dst_idx + h_k_idx * w_k; + for w_k_idx in 0..w_k { + let src_w = w_idx * stride + w_k_idx * dilation; + if padding != 0 && (src_w < padding || src_w >= w + padding) { + continue; + } + let src_w = src_w - padding; + let src_idx = src_idx + src_w * src_s3; + let dst_idx = dst_idx + w_k_idx; + dst[dst_idx] = src[src_idx] + } + } + } + } + } + } + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b * h_out * w_out, c * h_k * w_k).into())) + } +} + +// Conv1d example as used in whisper. +struct Conv1d; +impl Benchmark for Conv1d { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result { + let inp = Tensor::randn(0f32, 1., (1, 384, 3000), &Device::Cpu)?; + let w = Tensor::randn(0f32, 1., (384, 384, 3), &Device::Cpu)?; + Ok((inp, w)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + d.0.conv1d(&d.1, 0, 1, 1, 1) + } + + const ITERS: usize = 5; +} + +// Conv2d example as used in stable-diffusion. +struct Conv2d; +impl Benchmark for Conv2d { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + + fn preprocess() -> Result { + let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; + let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; + Ok((inp, w)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + d.0.conv2d(&d.1, 0, 1, 1, 1) + } + + const ITERS: usize = 5; +} + +// Conv2d example as used in stable-diffusion, im2col implementation. +struct Conv2dIm2Col; +impl Benchmark for Conv2dIm2Col { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + + fn preprocess() -> Result { + let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; + let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; + Ok((inp, w)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + // d.0.conv2d(&d.1, 0, 1, 1, 1) + let (b, _, h, w) = d.0.dims4()?; + let (_, _, h_k, w_k) = d.1.dims4()?; + let op = Im2Col { + h_k, + w_k, + stride: 1, + dilation: 1, + padding: 0, + }; + let (h_out, w_out) = op.hw_out(h, w); + let col = d.0.apply_op1_no_bwd(&op)?; + let res = col.matmul(&d.1.flatten_from(1)?.t()?)?; + let res = res + .reshape((b, h_out, w_out, ()))? + .permute((0, 3, 1, 2))? 
+ .contiguous()?; + if CHECK_CONV2D { + let res2 = d.0.conv2d(&d.1, op.padding, op.stride, op.dilation, 1); + let diff = (&res - res2)?.sqr()?.mean_all()?; + println!("{diff}"); + } + Ok(res) + } + + const ITERS: usize = 5; +} + +struct MatMul; +impl Benchmark for MatMul { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result { + let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; + let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; + Ok((lhs, rhs)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + d.0.matmul(&d.1) + } + + const ITERS: usize = 100; +} + +struct MatVec; +impl Benchmark for MatVec { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result { + let lhs = Tensor::randn(0f32, 1., (1024 * 4, 1024 * 4), &Device::Cpu)?; + let rhs = Tensor::randn(0f32, 1., (1024 * 4, 1), &Device::Cpu)?; + Ok((lhs, rhs)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + d.0.matmul(&d.1) + } + + const ITERS: usize = 100; +} + +// This benchmark is similar to: +// https://github.com/ggerganov/llama.cpp/blob/master/examples/benchmark/benchmark-matmult.cpp +struct QMatMul; +impl Benchmark for QMatMul { + type PreProcessData = (mcandle_core::quantized::QMatMul, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result { + let zeros = vec![mcandle_core::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32]; + let mm = mcandle_core::quantized::QTensor::new( + mcandle_core::quantized::QStorage::Cpu(Box::new(zeros)), + (4096, 11008), + )?; + let mm = mcandle_core::quantized::QMatMul::from_qtensor(mm)?; + let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?; + Ok((mm, arg)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + d.0.forward(&d.1) + } + + const ITERS: usize = 100; +} + +struct Cat; +impl Benchmark for Cat { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result { + let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?; + let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?; + Ok((lhs, rhs)) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + Tensor::cat(&[&d.0, &d.1], 2) + } + + const ITERS: usize = 1000; +} + +struct Softmax; +impl Benchmark for Softmax { + type PreProcessData = Tensor; + type RunResult = Tensor; + fn preprocess() -> Result { + // Typical whisper tiny size. + let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?; + Ok(x) + } + + fn run_one(d: &Self::PreProcessData) -> Result { + mcandle_nn::ops::softmax(d, D::Minus1) + } + + const ITERS: usize = 100; +} + +struct SoftmaxLastDim; +impl Benchmark for SoftmaxLastDim { + type PreProcessData = Tensor; + type RunResult = Tensor; + fn preprocess() -> Result { + // Typical whisper tiny size. 
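+ // Same input shape as the `Softmax` benchmark above; the two cases differ only in whether
+ // the generic `softmax(_, D::Minus1)` or the specialized `softmax_last_dim` path is exercised.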
+ let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?; + Ok(x) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + mcandle_nn::ops::softmax_last_dim(d) + } + + const ITERS: usize = 100; +} + +fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> { + use std::hint::black_box; + + let iters = iters.unwrap_or(B::ITERS); + let d = B::preprocess()?; + let start = std::time::Instant::now(); + for _iter in 0..iters { + let _res = black_box(B::run_one(black_box(&d))?); + } + println!("{:?}", start.elapsed() / iters as u32); + Ok(()) +} + +#[derive(Subcommand, Debug, Clone)] +enum Task { + Conv1d, + Conv2d, + Conv2dIm2Col, + Matmul, + Matvec, + Qmatmul, + Softmax, + SoftmaxLastDim, + Cat, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// The benchmark to be run. + #[command(subcommand)] + task: Task, + + #[arg(long)] + iters: Option<usize>, +} + +fn main() -> Result<()> { + let args = Args::parse(); + match args.task { + Task::Conv1d => run::<Conv1d>(args.iters)?, + Task::Conv2d => run::<Conv2d>(args.iters)?, + Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?, + Task::Matmul => run::<MatMul>(args.iters)?, + Task::Matvec => run::<MatVec>(args.iters)?, + Task::Softmax => run::<Softmax>(args.iters)?, + Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?, + Task::Qmatmul => run::<QMatMul>(args.iters)?, + Task::Cat => run::<Cat>(args.iters)?, + } + Ok(()) +} diff --git a/mcandle/mcandle-nn/src/activation.rs b/mcandle/mcandle-nn/src/activation.rs new file mode 100644 index 000000000..fa07c5ed5 --- /dev/null +++ b/mcandle/mcandle-nn/src/activation.rs @@ -0,0 +1,106 @@ +//! Activation Functions +//! +use mcandle_core::{Result, Tensor}; +use serde::Deserialize; + +#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum Activation { + #[default] + #[serde(alias = "gelu")] + Gelu, + #[serde(alias = "gelu_new")] + NewGelu, + Relu, + Relu2, + Relu6, + Silu, + Sigmoid, + HardSigmoid, + Swiglu, + Swish, + HardSwish, + Elu(f64), + LeakyRelu(f64), + #[serde(alias = "gelu_pytorch_tanh")] + GeluPytorchTanh, +} + +impl super::Module for Activation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::Gelu => xs.gelu_erf(), + // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 + Self::NewGelu => xs.gelu(), + Self::Relu => xs.relu(), + Self::Relu2 => xs.relu()?.sqr(), + Self::Relu6 => xs.clamp(0f32, 6f32), + Self::Silu => xs.silu(), + Self::Sigmoid => crate::ops::sigmoid(xs), + Self::HardSigmoid => crate::ops::hard_sigmoid(xs), + Self::Swiglu => crate::ops::swiglu(xs), + Self::Swish => xs * crate::ops::sigmoid(xs)?, + Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?, + &Self::Elu(alpha) => xs.elu(alpha), + &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope), + Self::GeluPytorchTanh => xs.gelu(), + } + } +} + +#[derive(Clone, Debug)] +pub struct PReLU { + weight: Tensor, + is_scalar: bool, +} + +impl PReLU { + pub fn new(weight: Tensor, is_scalar: bool) -> Self { + Self { weight, is_scalar } + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn is_scalar(&self) -> bool { + self.is_scalar + } +} + +impl mcandle_core::Module for PReLU { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let weight = if self.is_scalar { + self.weight.reshape(())?
+ } else if xs.rank() >= 2 { + let num_channels = xs.dim(1)?; + let num_weights = self.weight.elem_count(); + if num_weights != num_channels { + mcandle_core::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}") + } + let mut s = vec![1; xs.rank()]; + s[1] = self.weight.elem_count(); + self.weight.reshape(s)? + } else { + self.weight.clone() + }; + let zeros = xs.zeros_like()?; + xs.maximum(&zeros)? + xs.minimum(&zeros)?.broadcast_mul(&weight)? + } +} + +/// Create or initialize a new PReLU layer. +/// +/// This uses some default name for weights, namely `"weight"`. +/// # Arguments +/// +/// * `num_channels` - The number of channels. Use `None` to have as single trainable value and +/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward` +/// function, the input tensor shape `s` should either be one dimension with this number of +/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number. +pub fn prelu(num_channels: Option, vs: crate::VarBuilder) -> Result { + let init_ws = crate::init::Init::Const(0.25); + // When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1. + let ws = vs.get_with_hints((num_channels.unwrap_or(1),), "weight", init_ws)?; + Ok(PReLU::new(ws, num_channels.is_none())) +} diff --git a/mcandle/mcandle-nn/src/attention.rs b/mcandle/mcandle-nn/src/attention.rs new file mode 100644 index 000000000..701162691 --- /dev/null +++ b/mcandle/mcandle-nn/src/attention.rs @@ -0,0 +1,63 @@ +use mcandle_core::{Result, Tensor}; + +#[cfg(feature = "flash-attn")] +pub fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + mcandle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +pub fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("Compile with '--features flash-attn'") +} + +/// Computes (softmax(QK^T*sqrt(d_k)) + M)V. `M` is the attention mask, and is a bias (0 for unmasked, -inf for masked). +/// +/// The attention implementation is automatically accelerated and dispatched as follows: +/// 1) If `use_flash_attn == true`, use a Flash Attention V2 kernel +/// 2) Otherwise, use SDPA with fusion of softmax scale and attention bias application +/// +/// Note that there may be minute differences in output because floating point operations are not associative. +#[allow(unused_variables, clippy::too_many_arguments)] +pub fn scaled_dot_product_attention( + q: &Tensor, + k: &Tensor, + v: &Tensor, + scale: f64, + mask: Option<&Tensor>, + use_flash_attn: bool, + seq_len: usize, +) -> Result { + if use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + return flash_attn(&q, &k, &v, scale as f32, seq_len > 1)?.transpose(1, 2); + } + + let att = match mask { + Some(mask) => { + let (b, n, s, _h) = q.dims4()?; + let mut mask_and_output = mask.broadcast_as((b, n, s, s))?.contiguous()?; + q.contiguous()?.matmul_with_alpha_beta( + &k.t()?.contiguous()?, + &mut mask_and_output, + Some(scale), + )?; + mask_and_output + } + None => q + .contiguous()? + .matmul_with_alpha(&k.t()?.contiguous()?, Some(scale))?, + }; + + let att = crate::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?) 
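+ // In the masked branch above, the softmax scale and the additive mask are folded into a
+ // single GEMM through `matmul_with_alpha_beta`, which accumulates into the broadcast mask
+ // buffer; the unmasked branch applies the same scale through `matmul_with_alpha`.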
+} diff --git a/mcandle/mcandle-nn/src/batch_norm.rs b/mcandle/mcandle-nn/src/batch_norm.rs new file mode 100644 index 000000000..c911a6a15 --- /dev/null +++ b/mcandle/mcandle-nn/src/batch_norm.rs @@ -0,0 +1,328 @@ +//! Batch Normalization. +//! +//! This layer applies Batch Normalization over a mini-batch of inputs as described in [`Batch +//! Normalization`]. The input is expected to have at least three dimensions. +//! +//! Note that this implementation is for inference only, there is no possibility to track the +//! running stats. +//! +//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167 +use mcandle_core::{DType, Result, Tensor, Var}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct BatchNormConfig { + pub eps: f64, + pub remove_mean: bool, + + /// The meaning of affine here is different from LayerNorm: when false there is no learnable + /// parameter at all, 1 used for gamma and 0 for beta. + pub affine: bool, + + /// Controls exponential moving average of running stats. Defaults to 0.1 + /// + /// `running_stat * (1.0 - momentum) + stat * momentum`. + pub momentum: f64, +} + +impl Default for BatchNormConfig { + fn default() -> Self { + Self { + eps: 1e-5, + remove_mean: true, + affine: true, + momentum: 0.1, + } + } +} + +impl From for BatchNormConfig { + fn from(eps: f64) -> Self { + Self { + eps, + ..Default::default() + } + } +} + +#[derive(Clone, Debug)] +pub struct BatchNorm { + running_mean: Var, + running_var: Var, + weight_and_bias: Option<(Tensor, Tensor)>, + remove_mean: bool, + eps: f64, + momentum: f64, +} + +impl BatchNorm { + fn check_validity(&self, num_features: usize) -> Result<()> { + if self.eps < 0. { + mcandle_core::bail!("batch-norm eps cannot be negative {}", self.eps) + } + if !(0.0..=1.0).contains(&self.momentum) { + mcandle_core::bail!( + "batch-norm momentum must be between 0 and 1, is {}", + self.momentum + ) + } + if self.running_mean.dims() != [num_features] { + mcandle_core::bail!( + "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]", + self.running_mean.shape(), + ) + } + if self.running_var.dims() != [num_features] { + mcandle_core::bail!( + "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]", + self.running_var.shape(), + ) + } + if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() { + if weight.dims() != [num_features] { + mcandle_core::bail!( + "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]", + weight.shape(), + ) + } + if bias.dims() != [num_features] { + mcandle_core::bail!( + "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]", + bias.shape(), + ) + } + } + Ok(()) + } + + pub fn new( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + weight: Tensor, + bias: Tensor, + eps: f64, + ) -> Result { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: Some((weight, bias)), + remove_mean: true, + eps, + momentum: 0.1, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_no_bias( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + eps: f64, + ) -> Result { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: None, + remove_mean: true, + eps, + momentum: 0.1, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_with_momentum( + num_features: usize, 
+ running_mean: Tensor, + running_var: Tensor, + weight: Tensor, + bias: Tensor, + eps: f64, + momentum: f64, + ) -> Result { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: Some((weight, bias)), + remove_mean: true, + eps, + momentum, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_no_bias_with_momentum( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + eps: f64, + momentum: f64, + ) -> Result { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: None, + remove_mean: true, + eps, + momentum, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn running_mean(&self) -> &Tensor { + self.running_mean.as_tensor() + } + + pub fn running_var(&self) -> &Tensor { + self.running_var.as_tensor() + } + + pub fn eps(&self) -> f64 { + self.eps + } + + pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> { + self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1)) + } + + pub fn momentum(&self) -> f64 { + self.momentum + } + + pub fn forward_train(&self, x: &Tensor) -> Result { + let num_features = self.running_mean.as_tensor().dim(0)?; + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + if x.rank() < 2 { + mcandle_core::bail!( + "batch-norm input tensor must have at least two dimensions ({:?})", + x.shape() + ) + } + if x.dim(1)? != num_features { + mcandle_core::bail!( + "batch-norm input doesn't have the expected number of features ({:?} <> {})", + x.shape(), + num_features + ) + } + let x = x.to_dtype(internal_dtype)?; + let x = x.transpose(0, 1)?; + let x_dims_post_transpose = x.dims(); + // Flatten all the dimensions exception the channel one as this performs a Spatial Batch + // Normalization. + let x = x.flatten_from(1)?.contiguous()?; + let x = if self.remove_mean { + // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above. + let mean_x = x.mean_keepdim(1)?; + let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))? + + (mean_x.flatten_all()? * self.momentum)?)?; + self.running_mean.set(&updated_running_mean)?; + x.broadcast_sub(&mean_x)? + } else { + x + }; + // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above. + let norm_x = x.sqr()?.mean_keepdim(1)?; + let updated_running_var = { + let batch_size = x.dim(1)? as f64; + let running_var_weight = 1.0 - self.momentum; + let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0); + ((self.running_var.as_tensor() * running_var_weight)? + + (&norm_x.flatten_all()? * norm_x_weight)?)? + }; + self.running_var.set(&updated_running_var)?; + let x = x + .broadcast_div(&(norm_x + self.eps)?.sqrt()?)? + .to_dtype(x_dtype)?; + let x = match &self.weight_and_bias { + None => x, + Some((weight, bias)) => { + let weight = weight.reshape(((), 1))?; + let bias = bias.reshape(((), 1))?; + x.broadcast_mul(&weight)?.broadcast_add(&bias)? + } + }; + x.reshape(x_dims_post_transpose)?.transpose(0, 1) + } + + fn forward_eval(&self, x: &Tensor) -> Result { + let target_shape: Vec = x + .dims() + .iter() + .enumerate() + .map(|(idx, v)| if idx == 1 { *v } else { 1 }) + .collect(); + let target_shape = target_shape.as_slice(); + + let x = x + .broadcast_sub( + &self + .running_mean + .as_detached_tensor() + .reshape(target_shape)?, + )? 
+ .broadcast_div( + &(self + .running_var + .as_detached_tensor() + .reshape(target_shape)? + + self.eps)? + .sqrt()?, + )?; + + match &self.weight_and_bias { + None => Ok(x), + Some((weight, bias)) => { + let weight = weight.reshape(target_shape)?; + let bias = bias.reshape(target_shape)?; + x.broadcast_mul(&weight)?.broadcast_add(&bias) + } + } + } +} + +impl crate::ModuleT for BatchNorm { + fn forward_t(&self, x: &Tensor, train: bool) -> Result { + if train { + self.forward_train(x) + } else { + self.forward_eval(x) + } + } +} + +pub fn batch_norm>( + num_features: usize, + config: C, + vb: crate::VarBuilder, +) -> Result { + use crate::Init; + let config = config.into(); + if config.eps < 0. { + mcandle_core::bail!("batch-norm eps cannot be negative {}", config.eps) + } + let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?; + let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?; + let weight_and_bias = if config.affine { + let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?; + let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?; + Some((weight, bias)) + } else { + None + }; + Ok(BatchNorm { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias, + remove_mean: config.remove_mean, + eps: config.eps, + momentum: config.momentum, + }) +} diff --git a/mcandle/mcandle-nn/src/conv.rs b/mcandle/mcandle-nn/src/conv.rs new file mode 100644 index 000000000..8afa44d87 --- /dev/null +++ b/mcandle/mcandle-nn/src/conv.rs @@ -0,0 +1,467 @@ +//! Convolution Layers. +use crate::BatchNorm; +use mcandle_core::{Result, Tensor}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Conv1dConfig { + pub padding: usize, + pub stride: usize, + pub dilation: usize, + pub groups: usize, +} + +impl Default for Conv1dConfig { + fn default() -> Self { + Self { + padding: 0, + stride: 1, + dilation: 1, + groups: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct Conv1d { + weight: Tensor, + bias: Option, + config: Conv1dConfig, +} + +impl Conv1d { + pub fn new(weight: Tensor, bias: Option, config: Conv1dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &Conv1dConfig { + &self.config + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } +} + +impl crate::Module for Conv1d { + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv1d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.dilation, + self.config.groups, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1))?; + Ok(x.broadcast_add(&bias)?) 
+ } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConvTranspose1dConfig { + pub padding: usize, + pub output_padding: usize, + pub stride: usize, + pub dilation: usize, + pub groups: usize, +} + +impl Default for ConvTranspose1dConfig { + fn default() -> Self { + Self { + padding: 0, + output_padding: 0, + stride: 1, + dilation: 1, + groups: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConvTranspose1d { + weight: Tensor, + bias: Option, + config: ConvTranspose1dConfig, +} + +impl ConvTranspose1d { + pub fn new(weight: Tensor, bias: Option, config: ConvTranspose1dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &ConvTranspose1dConfig { + &self.config + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } +} + +impl crate::Module for ConvTranspose1d { + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv_transpose1d( + &self.weight, + self.config.padding, + self.config.output_padding, + self.config.stride, + self.config.dilation, + self.config.groups, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Conv2dConfig { + pub padding: usize, + pub stride: usize, + pub dilation: usize, + pub groups: usize, +} + +impl Default for Conv2dConfig { + fn default() -> Self { + Self { + padding: 0, + stride: 1, + dilation: 1, + groups: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct Conv2d { + weight: Tensor, + bias: Option, + config: Conv2dConfig, +} + +impl Conv2d { + pub fn new(weight: Tensor, bias: Option, config: Conv2dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &Conv2dConfig { + &self.config + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } + + pub fn absorb_bn(&self, bn: &BatchNorm) -> Result { + if let Some((w_bn, b_bn)) = bn.weight_and_bias() { + let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?; + let weight = self + .weight() + .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?; + let bias = match &self.bias { + None => b_bn.sub(&(std_.mul(bn.running_mean())?))?, + Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?, + }; + Ok(Self { + weight, + bias: Some(bias), + config: self.config, + }) + } else { + mcandle_core::bail!("batch norm does not have weight_and_bias") + } + } +} + +impl crate::Module for Conv2d { + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv2d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.dilation, + self.config.groups, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConvTranspose2dConfig { + pub padding: usize, + pub output_padding: usize, + pub stride: usize, + pub dilation: usize, + // TODO: support groups. 
+} + +impl Default for ConvTranspose2dConfig { + fn default() -> Self { + Self { + padding: 0, + output_padding: 0, + stride: 1, + dilation: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConvTranspose2d { + weight: Tensor, + bias: Option, + config: ConvTranspose2dConfig, +} + +impl ConvTranspose2d { + pub fn new(weight: Tensor, bias: Option, config: ConvTranspose2dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &ConvTranspose2dConfig { + &self.config + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } +} + +impl crate::Module for ConvTranspose2d { + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv_transpose2d( + &self.weight, + self.config.padding, + self.config.output_padding, + self.config.stride, + self.config.dilation, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + +pub fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv1dConfig, + vb: crate::VarBuilder, +) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints( + (out_channels, in_channels / cfg.groups, kernel_size), + "weight", + init_ws, + )?; + let bound = 1. / (in_channels as f64).sqrt(); + let init_bs = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; + Ok(Conv1d::new(ws, Some(bs), cfg)) +} + +pub fn conv1d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv1dConfig, + vb: crate::VarBuilder, +) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints( + (out_channels, in_channels / cfg.groups, kernel_size), + "weight", + init_ws, + )?; + Ok(Conv1d::new(ws, None, cfg)) +} + +pub fn conv_transpose1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose1dConfig, + vb: crate::VarBuilder, +) -> Result { + let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; + let bs = vb.get_with_hints(out_channels, "bias", init)?; + Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) +} + +pub fn conv_transpose1d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose1dConfig, + vb: crate::VarBuilder, +) -> Result { + let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; + Ok(ConvTranspose1d::new(ws, None, cfg)) +} + +pub fn conv2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv2dConfig, + vb: crate::VarBuilder, +) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints( + ( + out_channels, + in_channels / cfg.groups, + kernel_size, + kernel_size, + ), + "weight", + init_ws, + )?; + let bound = 1. 
/ (in_channels as f64).sqrt(); + let init_bs = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vb.get_with_hints(out_channels, "bias", init_bs)?; + Ok(Conv2d::new(ws, Some(bs), cfg)) +} + +pub fn conv2d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: Conv2dConfig, + vb: crate::VarBuilder, +) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints( + ( + out_channels, + in_channels / cfg.groups, + kernel_size, + kernel_size, + ), + "weight", + init_ws, + )?; + Ok(Conv2d::new(ws, None, cfg)) +} + +pub fn conv_transpose2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose2dConfig, + vb: crate::VarBuilder, +) -> Result { + let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints( + (in_channels, out_channels, kernel_size, kernel_size), + "weight", + init, + )?; + let bs = vb.get_with_hints(out_channels, "bias", init)?; + Ok(ConvTranspose2d::new(ws, Some(bs), cfg)) +} + +pub fn conv_transpose2d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose2dConfig, + vb: crate::VarBuilder, +) -> Result { + let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64; + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints( + (in_channels, out_channels, kernel_size, kernel_size), + "weight", + init, + )?; + Ok(ConvTranspose2d::new(ws, None, cfg)) +} diff --git a/mcandle/mcandle-nn/src/embedding.rs b/mcandle/mcandle-nn/src/embedding.rs new file mode 100644 index 000000000..3a4ba057d --- /dev/null +++ b/mcandle/mcandle-nn/src/embedding.rs @@ -0,0 +1,49 @@ +//! Embedding Layer. +use mcandle_core::{Result, Tensor}; + +#[derive(Clone, Debug)] +pub struct Embedding { + embeddings: Tensor, + hidden_size: usize, +} + +impl Embedding { + pub fn new(embeddings: Tensor, hidden_size: usize) -> Self { + Self { + embeddings, + hidden_size, + } + } + + pub fn embeddings(&self) -> &Tensor { + &self.embeddings + } + + /// Get the hidden size of the embedding matrix + pub fn hidden_size(&self) -> usize { + self.hidden_size + } +} + +impl crate::Module for Embedding { + fn forward(&self, indexes: &Tensor) -> Result { + let mut final_dims = indexes.dims().to_vec(); + final_dims.push(self.hidden_size); + let indexes = indexes.flatten_all()?; + let values = self.embeddings.index_select(&indexes, 0)?; + let values = values.reshape(final_dims)?; + Ok(values) + } +} + +pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result { + let embeddings = vb.get_with_hints( + (in_size, out_size), + "weight", + crate::Init::Randn { + mean: 0., + stdev: 1., + }, + )?; + Ok(Embedding::new(embeddings, out_size)) +} diff --git a/mcandle/mcandle-nn/src/encoding.rs b/mcandle/mcandle-nn/src/encoding.rs new file mode 100644 index 000000000..47118dbbd --- /dev/null +++ b/mcandle/mcandle-nn/src/encoding.rs @@ -0,0 +1,150 @@ +//! Encoding Utilities. (e.g., one-hot/cold encoding) + +use mcandle_core::{bail, DType, Result, Tensor, WithDType}; + +/// One-hot/cold encoding. +/// +/// Given an input tensor of indices, this function returns a tensor of the same shape as the input +/// tensor with an additional dimension of the given depth size. The values in the returned tensor are +/// all set to the `off_value` except for the positions represented by the indices, which are set to the `on_value`. 
+/// +/// This method returns a tensor with a rank that is one rank larger than the input tensor. +/// +/// As an example, the following tensor will be encoded to a one-hot matrix: +/// +/// `[[0i64, 2], [1, -1]]` +/// +/// with a depth of 4 will be encoded to: +/// +/// `[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]` +/// +/// When the input tensor index has a value of -1, the corresponding one-hot vector will be ignored, +/// resulting in a vector of values set to the `off_value`. +/// +/// +/// This method supports one-cold encoding by setting `on_value` to `0` and `off_value` to `1`. +/// By default `on_value` is `1` and `off_value` is `0`. +/// +/// Other encoding values can be used by setting `on_value` and `off_value` to the desired values. +/// +/// # Examples +/// +/// ## One-hot encoding +/// +/// ```rust +/// use mcandle_core::{Shape, Tensor, Device}; +/// use mcandle_nn::encoding::one_hot; +/// +/// let device = mcandle_core::Device::Cpu; +/// +/// let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap(); +/// let depth = 4; +/// let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap(); +/// +/// let expected_matrix = [ +/// [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], +/// [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], +/// ]; +/// +/// assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth))); +/// +/// let matrix = one_hot.to_vec3::().unwrap(); +/// +/// assert_eq!(matrix, expected_matrix); +///``` +/// ## One-cold Encoding +/// +/// ```rust +/// use mcandle_core::{Shape, Tensor, Device}; +/// use mcandle_nn::encoding::one_hot; +/// +/// +/// let device = mcandle_core::Device::Cpu; +/// let depth = 4; +/// let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap(); +/// let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap(); +/// +/// let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]]; +/// +/// assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth))); +/// +/// let matrix = one_cold.to_vec3::().unwrap(); +/// +/// assert_eq!(matrix, expected_matrix); +/// ``` +/// +/// +/// # Bails +/// +/// This method bails if: +/// - One of the index value is less than -1. +/// - One of the index value is greater than or equal to the depth value. +/// - The input data type is not `U8`, `U32`, or `I64`. +/// +/// # API Design +/// +/// The api design for this method is loosely based on the [TensorFlow One-Hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) method. 
+pub fn one_hot( + indices: Tensor, + depth: usize, + on_value: D, + off_value: D, +) -> Result { + let mut target_shape = indices.dims().to_vec(); + target_shape.push(depth); + let indices = indices.flatten_all()?; + let mut out = vec![off_value; depth * indices.elem_count()]; + match indices.dtype() { + DType::U8 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } + } + DType::U32 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } + } + DType::I64 => { + let indices = indices.to_vec1::()?; + for (i, &index) in indices.iter().enumerate() { + set_at_index(index, i * depth, depth, &mut out, on_value)?; + } + } + dtype => { + bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64") + } + }; + Tensor::from_vec(out, target_shape, indices.device()) +} + +fn set_at_index>( + value: I, + offset: usize, + depth: usize, + v: &mut [D], + on_value: D, +) -> Result<()> { + let value = value.into(); + // Skip for an entire row of off_values + if value == -1 { + return Ok(()); + } + if value < -1 { + bail!( + "one_hot: invalid negative index value {value}, expected a positive index value or -1" + ); + } + let value = value as usize; + if value >= depth { + bail!("one_hot: index value {value} exceeds depth {depth}") + } + let idx = offset + value; + if idx >= v.len() { + bail!("one_hot: index out of bounds {idx}, len {}", v.len()); + } + v[idx] = on_value; + Ok(()) +} diff --git a/mcandle/mcandle-nn/src/func.rs b/mcandle/mcandle-nn/src/func.rs new file mode 100644 index 000000000..30b9c3010 --- /dev/null +++ b/mcandle/mcandle-nn/src/func.rs @@ -0,0 +1,73 @@ +//! Layers defined by closures. +use mcandle_core::{Result, Tensor}; +use std::sync::Arc; + +/// A layer defined by a simple closure. +#[derive(Clone)] +pub struct Func<'a> { + #[allow(clippy::type_complexity)] + f: Arc Result + Send + Sync>, +} + +impl<'a> std::fmt::Debug for Func<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "func") + } +} + +pub fn func<'a, F>(f: F) -> Func<'a> +where + F: 'a + Fn(&Tensor) -> Result + Send + Sync, +{ + Func { f: Arc::new(f) } +} + +impl<'a> super::Module for Func<'a> { + fn forward(&self, xs: &Tensor) -> Result { + (*self.f)(xs) + } +} + +impl<'a> Func<'a> { + pub fn new(f: F) -> Self + where + F: 'a + Fn(&Tensor) -> Result + Send + Sync, + { + Self { f: Arc::new(f) } + } +} + +/// A layer defined by a simple closure. +#[derive(Clone)] +pub struct FuncT<'a> { + #[allow(clippy::type_complexity)] + f: Arc Result + Send + Sync>, +} + +impl<'a> std::fmt::Debug for FuncT<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "func") + } +} + +pub fn func_t<'a, F>(f: F) -> FuncT<'a> +where + F: 'a + Fn(&Tensor, bool) -> Result + Send + Sync, +{ + FuncT { f: Arc::new(f) } +} + +impl<'a> super::ModuleT for FuncT<'a> { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { + (*self.f)(xs, train) + } +} + +impl<'a> FuncT<'a> { + pub fn new(f: F) -> Self + where + F: 'a + Fn(&Tensor, bool) -> Result + Send + Sync, + { + Self { f: Arc::new(f) } + } +} diff --git a/mcandle/mcandle-nn/src/group_norm.rs b/mcandle/mcandle-nn/src/group_norm.rs new file mode 100644 index 000000000..033eae1be --- /dev/null +++ b/mcandle/mcandle-nn/src/group_norm.rs @@ -0,0 +1,85 @@ +//! Group Normalization. +//! +//! 
This layer applies Group Normalization over a mini-batch of inputs. +use mcandle_core::{DType, Result, Tensor}; + +// This group norm version handles both weight and bias so removes the mean. +#[derive(Clone, Debug)] +pub struct GroupNorm { + weight: Tensor, + bias: Tensor, + eps: f64, + num_channels: usize, + num_groups: usize, +} + +impl GroupNorm { + pub fn new( + weight: Tensor, + bias: Tensor, + num_channels: usize, + num_groups: usize, + eps: f64, + ) -> Result { + if num_channels % num_groups != 0 { + mcandle_core::bail!( + "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})" + ) + } + Ok(Self { + weight, + bias, + eps, + num_channels, + num_groups, + }) + } +} + +impl crate::Module for GroupNorm { + fn forward(&self, x: &Tensor) -> Result { + let x_shape = x.dims(); + if x_shape.len() <= 2 { + mcandle_core::bail!("input rank for GroupNorm should be at least 3"); + } + let (b_sz, n_channels) = (x_shape[0], x_shape[1]); + let hidden_size = x_shape[2..].iter().product::() * n_channels / self.num_groups; + if n_channels != self.num_channels { + mcandle_core::bail!( + "unexpected num-channels in GroupNorm ({n_channels} <> {}", + self.num_channels + ) + } + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let x = x.reshape((b_sz, self.num_groups, hidden_size))?; + let x = x.to_dtype(internal_dtype)?; + let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let x = x.broadcast_sub(&mean_x)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let mut w_dims = vec![1; x_shape.len()]; + w_dims[1] = n_channels; + let weight = self.weight.reshape(w_dims.clone())?; + let bias = self.bias.reshape(w_dims)?; + x_normed + .to_dtype(x_dtype)? + .reshape(x_shape)? + .broadcast_mul(&weight)? + .broadcast_add(&bias) + } +} + +pub fn group_norm( + num_groups: usize, + num_channels: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result { + let weight = vb.get_with_hints(num_channels, "weight", crate::Init::Const(1.))?; + let bias = vb.get_with_hints(num_channels, "bias", crate::Init::Const(0.))?; + GroupNorm::new(weight, bias, num_channels, num_groups, eps) +} diff --git a/mcandle/mcandle-nn/src/init.rs b/mcandle/mcandle-nn/src/init.rs new file mode 100644 index 000000000..22f054914 --- /dev/null +++ b/mcandle/mcandle-nn/src/init.rs @@ -0,0 +1,147 @@ +//! Variable initialization. +// This is based on: +// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py# +use mcandle_core::{DType, Device, Result, Shape, Tensor, Var}; + +/// Number of features as input or output of a layer. +/// In Kaiming initialization, choosing `FanIn` preserves +/// the magnitude of the variance of the weights in the +/// forward pass, choosing `FanOut` preserves this +/// magnitude in the backward pass. +#[derive(Debug, Copy, Clone)] +pub enum FanInOut { + FanIn, + FanOut, +} + +impl FanInOut { + /// Compute the fan-in or fan-out value for a weight tensor of + /// the specified dimensions. 
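+ /// For a 2d weight of shape `(out_features, in_features)` this is `in_features` (fan-in) or
+ /// `out_features` (fan-out); for a convolution weight of shape `(out_c, in_c, k_h, k_w)` the
+ /// receptive field is folded in, giving `in_c * k_h * k_w` and `out_c * k_h * k_w` respectively.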
+ /// + pub fn for_shape(&self, shape: &Shape) -> usize { + let dims = shape.dims(); + let receptive_field_size: usize = dims.iter().skip(2).product(); + match &self { + FanInOut::FanIn => { + if dims.len() < 2 { + 1 + } else { + dims[1] * receptive_field_size + } + } + FanInOut::FanOut => { + if dims.is_empty() { + 1 + } else { + dims[0] * receptive_field_size + } + } + } + } +} + +#[derive(Debug, Copy, Clone)] +pub enum NormalOrUniform { + Normal, + Uniform, +} + +/// The non-linear function that follows this layer. ReLU is the +/// recommended value. +#[derive(Debug, Copy, Clone)] +pub enum NonLinearity { + ReLU, + Linear, + Sigmoid, + Tanh, + SELU, + ExplicitGain(f64), +} + +impl NonLinearity { + // https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#L67 + pub fn gain(&self) -> f64 { + match *self { + NonLinearity::ReLU => 2f64.sqrt(), + NonLinearity::Tanh => 5. / 3., + NonLinearity::Linear | NonLinearity::Sigmoid => 1., + NonLinearity::SELU => 0.75, + NonLinearity::ExplicitGain(g) => g, + } + } +} + +/// Variable initializations. +#[derive(Debug, Copy, Clone)] +pub enum Init { + /// Constant value. + Const(f64), + + /// Random normal with some mean and standard deviation. + Randn { mean: f64, stdev: f64 }, + + /// Uniform initialization between some lower and upper bounds. + Uniform { lo: f64, up: f64 }, + + /// Kaiming uniform initialization. + /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification" + /// He, K. et al. (2015). This uses a uniform distribution. + Kaiming { + dist: NormalOrUniform, + fan: FanInOut, + non_linearity: NonLinearity, + }, +} + +pub const ZERO: Init = Init::Const(0.); +pub const ONE: Init = Init::Const(1.); + +pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming { + dist: NormalOrUniform::Uniform, + fan: FanInOut::FanIn, + non_linearity: NonLinearity::ReLU, +}; + +pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming { + dist: NormalOrUniform::Normal, + fan: FanInOut::FanIn, + non_linearity: NonLinearity::ReLU, +}; + +impl Init { + /// Creates a new tensor with the specified shape, device, and initialization. + pub fn var>(&self, s: S, dtype: DType, device: &Device) -> Result { + match self { + Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device), + Self::Const(v) if *v == 1. => Var::ones(s, dtype, device), + Self::Const(cst) => { + Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?) + } + Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device), + Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device), + Self::Kaiming { + dist, + fan, + non_linearity, + } => { + let s = s.into(); + let fan = fan.for_shape(&s); + let gain = non_linearity.gain(); + let std = gain / (fan as f64).sqrt(); + match dist { + NormalOrUniform::Uniform => { + let bound = 3f64.sqrt() * std; + Var::rand_f64(-bound, bound, s, dtype, device) + } + NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device), + } + } + } + } +} + +impl Default for Init { + fn default() -> Self { + Self::Const(0.) + } +} diff --git a/mcandle/mcandle-nn/src/kv_cache.rs b/mcandle/mcandle-nn/src/kv_cache.rs new file mode 100644 index 000000000..9ad0834a1 --- /dev/null +++ b/mcandle/mcandle-nn/src/kv_cache.rs @@ -0,0 +1,371 @@ +//! Cache Implementations +//! 
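+//! This file defines a fixed-size [`KvCache`] (a pair of [`Cache`] buffers for keys and values)
+//! and a [`RotatingKvCache`] whose ring buffer only retains the most recent `max_seq_len`
+//! positions while the logical sequence keeps growing.
+//!
+//! A minimal usage sketch; the `mcandle_nn::kv_cache` path and the tensor shapes below are
+//! illustrative assumptions:
+//!
+//! ```ignore
+//! use mcandle_core::{DType, Device, Tensor};
+//! use mcandle_nn::kv_cache::KvCache;
+//!
+//! fn demo() -> mcandle_core::Result<()> {
+//!     // Cache along dimension 2 (the sequence dimension), for at most 512 positions.
+//!     let mut cache = KvCache::new(2, 512);
+//!     let k = Tensor::zeros((1, 8, 4, 64), DType::F32, &Device::Cpu)?;
+//!     let v = Tensor::zeros((1, 8, 4, 64), DType::F32, &Device::Cpu)?;
+//!     // `append` returns everything cached so far, narrowed to the current sequence length.
+//!     let (k_all, _v_all) = cache.append(&k, &v)?;
+//!     assert_eq!(k_all.dims()[2], 4);
+//!     assert_eq!(cache.current_seq_len(), 4);
+//!     Ok(())
+//! }
+//! ```
+//!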
+use mcandle_core::{Device, Result, Tensor}; + +#[derive(Debug, Clone)] +pub struct Cache { + // all_data is an option on a Tensor, this makes it possible to only create the actual tensor + // on the first call where the batch size is easily known. + // Also this makes it safe to clone a KvCache that has been reseted (as in it will not share + // its internal state with the cloned instance). + all_data: Option, + dim: usize, + current_seq_len: usize, + max_seq_len: usize, +} + +impl Cache { + pub fn new(dim: usize, max_seq_len: usize) -> Self { + Self { + all_data: None, + dim, + current_seq_len: 0, + max_seq_len, + } + } + + pub fn dim(&self) -> usize { + self.dim + } + + pub fn current_seq_len(&self) -> usize { + self.current_seq_len + } + + pub fn max_seq_len(&self) -> usize { + self.max_seq_len + } + + pub fn all_data(&self) -> &Option { + &self.all_data + } + + pub fn current_data(&self) -> Result> { + let data = match self.all_data.as_ref() { + None => None, + Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?), + }; + Ok(data) + } + + pub fn reset(&mut self) { + self.current_seq_len = 0; + self.all_data = None; + } + + pub fn append(&mut self, src: &Tensor) -> Result<()> { + let seq_len = src.dim(self.dim)?; + // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use + // self.all_data.get_or_insert_with. + if self.all_data.is_none() { + let mut shape = src.dims().to_vec(); + shape[self.dim] = self.max_seq_len; + let ad = Tensor::zeros(shape, src.dtype(), src.device())?; + self.all_data = Some(ad) + }; + let ad = self.all_data.as_mut().unwrap(); + if self.current_seq_len + seq_len > self.max_seq_len { + mcandle_core::bail!( + "kv-cache: above max-seq-len {}+{seq_len}>{}", + self.current_seq_len, + self.max_seq_len + ) + } + ad.slice_set(src, self.dim, self.current_seq_len)?; + self.current_seq_len += seq_len; + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct KvCache { + k: Cache, + v: Cache, +} + +impl KvCache { + pub fn new(dim: usize, max_seq_len: usize) -> Self { + let k = Cache::new(dim, max_seq_len); + let v = Cache::new(dim, max_seq_len); + Self { k, v } + } + + pub fn k_cache(&self) -> &Cache { + &self.k + } + + pub fn v_cache(&self) -> &Cache { + &self.v + } + + pub fn k_cache_mut(&mut self) -> &mut Cache { + &mut self.k + } + + pub fn v_cache_mut(&mut self) -> &mut Cache { + &mut self.v + } + + pub fn k(&self) -> Result> { + self.k.current_data() + } + + pub fn v(&self) -> Result> { + self.v.current_data() + } + + pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { + self.k.append(k)?; + self.v.append(v)?; + let out_k = self.k.current_data()?; + let out_v = self.v.current_data()?; + let k = match out_k { + None => { + let mut shape = k.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, k.dtype(), k.device())? + } + Some(k) => k, + }; + let v = match out_v { + None => { + let mut shape = v.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, v.dtype(), v.device())? + } + Some(v) => v, + }; + Ok((k, v)) + } + + pub fn current_seq_len(&self) -> usize { + self.k.current_seq_len() + } + + pub fn reset(&mut self) { + self.k.reset(); + self.v.reset(); + } +} + +#[derive(Debug, Clone)] +pub struct RotatingCache { + all_data: Option, + dim: usize, + // `offset` is the current write index in the buffer + offset: usize, + // The total size of the sequence seen so far. 
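+ // Unlike `offset`, this is never wrapped, so it can exceed `max_seq_len` once the buffer
+ // has rotated.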
+ current_seq_len: usize, + // max_seq_len is the size of the rotating buffer, it is actually allowed for the full + // sequence to grow past this limit. + max_seq_len: usize, +} + +impl RotatingCache { + pub fn new(dim: usize, max_seq_len: usize) -> Self { + Self { + all_data: None, + dim, + offset: 0, + current_seq_len: 0, + max_seq_len, + } + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn dim(&self) -> usize { + self.dim + } + + pub fn current_seq_len(&self) -> usize { + self.current_seq_len + } + + pub fn max_seq_len(&self) -> usize { + self.max_seq_len + } + + pub fn all_data(&self) -> &Option { + &self.all_data + } + + pub fn current_data(&self) -> Result> { + let data = match self.all_data.as_ref() { + None => None, + Some(d) => { + if self.current_seq_len >= self.max_seq_len { + Some(d.clone()) + } else { + Some(d.narrow(self.dim, 0, self.current_seq_len)?) + } + } + }; + Ok(data) + } + + pub fn reset(&mut self) { + self.offset = 0; + self.current_seq_len = 0; + self.all_data = None; + } + + pub fn append(&mut self, src: &Tensor) -> Result { + let seq_len = src.dim(self.dim)?; + // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use + // self.all_data.get_or_insert_with. + if self.all_data.is_none() { + let mut shape = src.dims().to_vec(); + shape[self.dim] = self.max_seq_len; + let ad = Tensor::zeros(shape, src.dtype(), src.device())?; + self.all_data = Some(ad) + }; + let ad = self.all_data.as_mut().unwrap(); + + self.current_seq_len += seq_len; + if seq_len >= self.max_seq_len { + let to_copy = src + .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)? + .contiguous()?; + ad.slice_set(&to_copy, self.dim, 0)?; + self.offset = 0; + // Here we return `src` rather than `ad` so that all the past can be used. + Ok(src.clone()) + } else { + let rem_len = self.max_seq_len - self.offset; + if seq_len <= rem_len { + ad.slice_set(&src.contiguous()?, self.dim, self.offset)?; + self.offset = (self.offset + seq_len) % self.max_seq_len; + } else { + // We have to make two copies here as we go over the boundary of the cache. + if rem_len > 0 { + let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?; + ad.slice_set(&src1, self.dim, self.offset)?; + } + let src2 = src + .narrow(self.dim, rem_len, seq_len - rem_len)? + .contiguous()?; + ad.slice_set(&src2, self.dim, 0)?; + self.offset = seq_len - rem_len; + } + if self.current_seq_len >= self.max_seq_len { + Ok(ad.clone()) + } else { + Ok(ad.narrow(self.dim, 0, self.current_seq_len)?) + } + } + } + + fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result { + let context = self.max_seq_len; + let mask: Vec<_> = (0..size1) + .flat_map(|i| { + (0..size2).map(move |j| { + u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i) + }) + }) + .collect(); + Tensor::from_slice(&mask, (size1, size2), device) + } + + fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result { + let context = self.max_seq_len; + let upd_offset = (self.offset + size1) % self.max_seq_len; + let mask: Vec<_> = (0..size1) + .flat_map(|pos_src| { + // The absolute position of the elements that will get added to the cache. + let pos_src = self.current_seq_len + pos_src; + (0..size2).map(move |pos_cache_rel| { + // The absolute position of the cache elements after the addition. 
+ let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset; + let pos_cache = if pos_cache_rel < upd_offset { + pos_cache + } else { + pos_cache - self.max_seq_len + }; + u8::from(pos_cache > pos_src || pos_cache + context < pos_src) + }) + }) + .collect(); + Tensor::from_slice(&mask, (size1, size2), device) + } + + /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache. + pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result> { + let mask = if seq_len == 1 { + None + } else { + let mask = if seq_len < self.max_seq_len { + let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len); + self.get_mask_rel(seq_len, cache_out_len, device)? + } else { + self.get_mask_abs(seq_len, seq_len, device)? + }; + Some(mask) + }; + Ok(mask) + } +} + +#[derive(Debug, Clone)] +pub struct RotatingKvCache { + k: RotatingCache, + v: RotatingCache, +} + +impl RotatingKvCache { + pub fn new(dim: usize, max_seq_len: usize) -> Self { + let k = RotatingCache::new(dim, max_seq_len); + let v = RotatingCache::new(dim, max_seq_len); + Self { k, v } + } + + pub fn k_cache(&self) -> &RotatingCache { + &self.k + } + + pub fn v_cache(&self) -> &RotatingCache { + &self.v + } + + pub fn k_cache_mut(&mut self) -> &mut RotatingCache { + &mut self.k + } + + pub fn v_cache_mut(&mut self) -> &mut RotatingCache { + &mut self.v + } + + pub fn k(&self) -> Result> { + self.k.current_data() + } + + pub fn v(&self) -> Result> { + self.v.current_data() + } + + pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { + let out_k = self.k.append(k)?; + let out_v = self.v.append(v)?; + Ok((out_k, out_v)) + } + + pub fn offset(&self) -> usize { + self.k.offset() + } + + pub fn current_seq_len(&self) -> usize { + self.k.current_seq_len() + } + + pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result> { + self.k.attn_mask(seq_len, device) + } + + pub fn reset(&mut self) { + self.k.reset(); + self.v.reset(); + } +} diff --git a/mcandle/mcandle-nn/src/layer_norm.rs b/mcandle/mcandle-nn/src/layer_norm.rs new file mode 100644 index 000000000..92b54f74f --- /dev/null +++ b/mcandle/mcandle-nn/src/layer_norm.rs @@ -0,0 +1,339 @@ +//! Layer Normalization. +//! +//! This layer applies Layer Normalization over a mini-batch of inputs as described in [`Layer +//! Normalization`]. The input is expected to have three dimensions: a batch dimension, a length, +//! and a hidden size, the normalization is applied over the last dimension. +//! +//! # Example +//! +//! ```rust +//! use mcandle_core::{Tensor, Device::Cpu, test_utils::to_vec3_round}; +//! use mcandle_nn::{LayerNorm, Module}; +//! # fn main() -> mcandle_core::Result<()> { +//! +//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?; +//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?; +//! let layer = LayerNorm::new(w, b, 1e-5); +//! +//! let xs = Tensor::new( +//! &[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], +//! &Cpu)?; +//! let ys = layer.forward(&xs)?; +//! assert_eq!( +//! to_vec3_round(&ys, 4)?, +//! &[[[-1.2247, 0.0, 1.2247], +//! [-1.2247, 0.0, 1.2247], +//! [ 1.2247, 0.0, -1.2247]]]); +//! # Ok(()) } +//! ``` +//! +//! 
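A minimal CPU sketch of the fixed-size `KvCache` added above, appending one token per step along dim 2 of a `(batch, heads, seq, head_dim)` layout; the shapes are illustrative, and `RotatingKvCache` is used the same way except that appends past `max_seq_len` wrap around instead of erroring:

```rust
use mcandle_core::{DType, Device, Result, Tensor};
use mcandle_nn::kv_cache::KvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Cache along dim 2 of a (batch, heads, seq, head_dim) layout, with room for 16 positions.
    let mut cache = KvCache::new(2, 16);
    for step in 0..3usize {
        // One new token per decoding step in this sketch.
        let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        // `append` writes into the pre-allocated buffer and returns the cache contents so far.
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!(k_all.dims()[2], step + 1);
        assert_eq!(v_all.dims()[2], step + 1);
    }
    assert_eq!(cache.current_seq_len(), 3);
    // Appending past max_seq_len bails; reset() clears the cache for a new sequence.
    cache.reset();
    Ok(())
}
```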
[`Layer Normalization`]: https://arxiv.org/abs/1607.06450 + +use std::marker::PhantomData; + +#[cfg(feature = "cuda")] +use mcandle_core::cuda_backend::{ + cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}, + kernel_name, kernels, CudaDType, WrapErr, +}; + +#[cfg(feature = "cuda")] +use mcandle_core::{ + backend::BackendStorage, from_storage_no_op, CudaDevice, CudaStorage, Device, Storage, + WithDType, +}; + +use mcandle_core::{DType, Module, Result, Tensor, D}; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct LayerNormConfig { + pub eps: f64, + /// Whether to remove the mean or not, the default is true and when set to false, this turns + /// this layer into RmsNorm. + pub remove_mean: bool, + pub affine: bool, +} + +impl Default for LayerNormConfig { + fn default() -> Self { + Self { + eps: 1e-5, + remove_mean: true, + affine: true, + } + } +} + +impl From for LayerNormConfig { + fn from(eps: f64) -> Self { + Self { + eps, + remove_mean: true, + affine: true, + } + } +} + +// This layer norm version handles both weight and bias so removes the mean. +#[derive(Clone, Debug)] +pub struct LayerNorm { + weight: Tensor, + bias: Tensor, + remove_mean: bool, + eps: f64, +} + +impl LayerNorm { + pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { + Self { + weight, + bias, + remove_mean: true, + eps, + } + } + + pub fn new_no_bias(weight: Tensor, eps: f64) -> Self { + Self { + weight: weight.clone(), + bias: Tensor::zeros_like(&weight).unwrap(), + remove_mean: true, + eps, + } + } + + pub fn rms_norm(weight: Tensor, eps: f64) -> Self { + Self { + weight: weight.clone(), + bias: Tensor::zeros_like(&weight).unwrap(), + remove_mean: false, + eps, + } + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> &Tensor { + &self.bias + } +} + +impl Module for LayerNorm { + fn forward(&self, x: &Tensor) -> Result { + if x.is_contiguous() && self.remove_mean { + return crate::ops::layer_norm(x, &self.weight, &self.bias, self.eps as f32); + } + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let x = if self.remove_mean { + let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + x.broadcast_sub(&mean_x)? + } else { + x + }; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; + x.broadcast_add(&self.bias) + } +} + +pub fn layer_norm>( + size: usize, + config: C, + vb: crate::VarBuilder, +) -> Result { + let config = config.into(); + let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?; + let bias = if config.affine { + Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?) + } else { + None + }; + Ok(LayerNorm { + weight: weight.clone(), + bias: bias.unwrap_or(Tensor::zeros_like(&weight)?), + remove_mean: config.remove_mean, + eps: config.eps, + }) +} + +// This whole non quantized/quantized RmsNorm is a hack. It seems like quantized works without this impl, but it is slower. +#[derive(Clone, Debug)] +pub struct RmsNormQuantized; +#[derive(Clone, Debug)] +pub struct RmsNormNonQuantized; + +/// RmsNorm is a specialized version of the LayerNorm module. 
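Before the RmsNorm wrapper below, a small sketch of the `remove_mean = false` path of `LayerNorm` (i.e. `LayerNorm::rms_norm`), checked against the formula spelled out in `forward` above; values are illustrative and this exercises the portable CPU path rather than a fused kernel:

```rust
use mcandle_core::{test_utils::to_vec3_round, Device, Result, Tensor, D};
use mcandle_nn::{LayerNorm, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::new(&[2f32, 2., 2.], &dev)?;
    let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.]]], &dev)?;

    // remove_mean = false turns the layer into x / sqrt(mean(x^2) + eps) * w with a zero bias.
    let rms = LayerNorm::rms_norm(w.clone(), 1e-5);
    let ys = rms.forward(&xs)?;

    // The same computation written out by hand from the forward pass above.
    let norm = (xs.sqr()?.sum_keepdim(D::Minus1)? / 3f64)?;
    let expected = xs
        .broadcast_div(&(norm + 1e-5)?.sqrt()?)?
        .broadcast_mul(&w)?;
    assert_eq!(to_vec3_round(&ys, 4)?, to_vec3_round(&expected, 4)?);
    Ok(())
}
```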
+#[derive(Clone, Debug)] +pub struct RmsNorm { + inner: LayerNorm, + _ghost: PhantomData, +} + +impl RmsNorm { + pub fn new(weight: Tensor, eps: f64) -> Self { + Self { + inner: LayerNorm::rms_norm(weight, eps), + _ghost: PhantomData, + } + } +} + +impl RmsNorm { + pub fn new(weight: Tensor, eps: f64) -> Self { + Self { + inner: LayerNorm::rms_norm(weight, eps), + _ghost: PhantomData, + } + } + + #[cfg(feature = "cuda")] + fn dtype_execute_rmsnorm( + &self, + dev: &CudaDevice, + eps_converter: F, + x_storage: &CudaStorage, + weight_storage: &CudaStorage, + x: &Tensor, + ) -> Result + where + F: FnOnce(f64) -> T, + { + assert!(x.layout().is_contiguous()); + let hidden_size = *x.dims().last().unwrap(); + let elem_count = x.elem_count(); + let num_tokens = elem_count / hidden_size; + let out = unsafe { dev.alloc::(elem_count) }.w()?; + + let cfg = LaunchConfig { + grid_dim: (num_tokens as u32, 1, 1), + block_dim: (u32::min(hidden_size as u32, 1024), 1, 1), + shared_mem_bytes: 0, + }; + + let func = dev.get_or_load_func(&kernel_name::("rms_norm"), kernels::FUSED_RMS_NORM)?; + + let params = ( + &out, + x_storage.as_cuda_slice::()?, + weight_storage.as_cuda_slice::()?, + eps_converter(self.inner.eps), + num_tokens as i32, + hidden_size as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + + Ok(from_storage_no_op( + Storage::Cuda(CudaStorage::wrap_cuda_slice(out, dev.clone())), + x.shape(), + false, + )) + } + + #[cfg(feature = "cuda")] + fn fused_rmsnorm(&self, x: &Tensor, dev: &CudaDevice) -> Result { + match ( + &*x.storage_and_layout().0, + &*self.inner.weight().storage_and_layout().0, + ) { + (Storage::Cuda(x_storage), Storage::Cuda(weight_storage)) => { + match (x_storage.dtype(), weight_storage.dtype()) { + (DType::BF16, DType::BF16) => self.dtype_execute_rmsnorm::( + dev, + |x| half::bf16::from_f64(x), + &x_storage, + &weight_storage, + x, + ), + (DType::F16, DType::F16) => self.dtype_execute_rmsnorm::( + dev, + |x| half::f16::from_f64(x), + &x_storage, + &weight_storage, + x, + ), + (DType::F32, DType::F32) => self.dtype_execute_rmsnorm::( + dev, + |x| x as f32, + &x_storage, + &weight_storage, + x, + ), + _ => mcandle_core::bail!("DType mismatch in fused rmsnorm."), + } + } + _ => unreachable!(), + } + } +} + +impl RmsNorm { + pub fn into_inner(self) -> LayerNorm { + self.inner + } + pub fn inner(&self) -> &LayerNorm { + &self.inner + } +} + +impl Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result { + self.inner.forward(xs) + } +} + +impl Module for RmsNorm { + fn forward(&self, xs: &Tensor) -> Result { + #[cfg(feature = "cuda")] + match (xs.dtype(), xs.device()) { + (DType::BF16, Device::Cuda(dev)) + | (DType::F32, Device::Cuda(dev)) + | (DType::F16, Device::Cuda(dev)) => return self.fused_rmsnorm(xs, &dev), + _ => return self.inner.forward(xs), + } + #[cfg(not(feature = "cuda"))] + { + self.inner.forward(xs) + } + } +} + +pub fn rms_norm_non_quant( + size: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result> { + let config = LayerNormConfig { + eps, + remove_mean: false, + affine: false, + }; + Ok(RmsNorm { + inner: layer_norm(size, config, vb)?, + _ghost: PhantomData, + }) +} + +pub fn rms_norm_quant( + size: usize, + eps: f64, + vb: crate::VarBuilder, +) -> Result> { + let config = LayerNormConfig { + eps, + remove_mean: false, + affine: false, + }; + Ok(RmsNorm { + inner: layer_norm(size, config, vb)?, + _ghost: PhantomData, + }) +} diff --git a/mcandle/mcandle-nn/src/lib.rs b/mcandle/mcandle-nn/src/lib.rs new file mode 100644 index 000000000..fc07e3022 --- 
/dev/null +++ b/mcandle/mcandle-nn/src/lib.rs @@ -0,0 +1,64 @@ +//! mcandle-nn +//! +//! ## Other Crates +//! +//! Candle consists of a number of crates. This crate holds structs and functions +//! that allow you to build and train neural nets. You may wish +//! to look at the docs for the other crates which can be found here: +//! +//! - [mcandle-core](https://docs.rs/mcandle-core/). Core Datastructures and DataTypes. +//! - [mcandle-nn](https://docs.rs/mcandle-nn/). Building blocks for Neural Nets. +//! - [mcandle-datasets](https://docs.rs/mcandle-datasets/). Rust access to commonly used Datasets like MNIST. +//! - [mcandle-examples](https://docs.rs/mcandle-examples/). Examples of Candle in Use. +//! - [mcandle-onnx](https://docs.rs/mcandle-onnx/). Loading and using ONNX models. +//! - [mcandle-pyo3](https://docs.rs/mcandle-pyo3/). Access to Candle from Python. +//! - [mcandle-transformers](https://docs.rs/mcandle-transformers/). Candle implemntation of many published transformer models. +//! + +pub mod activation; +pub mod attention; +pub mod batch_norm; +pub mod conv; +pub mod embedding; +pub mod encoding; +pub mod func; +pub mod group_norm; +pub mod init; +pub mod kv_cache; +pub mod layer_norm; +pub mod linear; +pub mod loss; +pub mod ops; +pub mod optim; +pub mod rnn; +pub mod rope; +pub mod rotary_emb; +pub mod sequential; +pub mod var_builder; +pub mod var_map; + +pub use activation::{prelu, Activation, PReLU}; +pub use attention::scaled_dot_product_attention; +pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; +pub use conv::{ + conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias, + conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, + ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig, +}; +pub use embedding::{embedding, Embedding}; +pub use func::{func, func_t, Func, FuncT}; +pub use group_norm::{group_norm, GroupNorm}; +pub use init::Init; +pub use layer_norm::{ + layer_norm, rms_norm_non_quant, rms_norm_quant, LayerNorm, LayerNormConfig, RmsNorm, +}; +pub use linear::{linear, linear_b, linear_no_bias, Linear}; +pub use ops::{kvconcat, Dropout}; +pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; +pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; +pub use rope::RotaryEmbedding; +pub use sequential::{seq, Sequential}; +pub use var_builder::VarBuilder; +pub use var_map::VarMap; + +pub use mcandle_core::{Module, ModuleT}; diff --git a/mcandle/mcandle-nn/src/linear.rs b/mcandle/mcandle-nn/src/linear.rs new file mode 100644 index 000000000..0c834d66b --- /dev/null +++ b/mcandle/mcandle-nn/src/linear.rs @@ -0,0 +1,90 @@ +//! Linear layer +//! +//! This layer applies a linear transformation to the incoming data, `y = x@w.t() + b`. +//! The bias is optional. The `forward` method can be used to apply the layer, it supports input +//! with a batch dimension (so of shape `(b_sz, in_c)`) or without (of shape `(in_c,)`), the +//! output has shape `(b_sz, out_c)` and `(out_c,)` respectively. +//! +//! ```rust +//! use mcandle_core::{Tensor, Device::Cpu}; +//! use mcandle_nn::{Linear, Module}; +//! # fn main() -> mcandle_core::Result<()> { +//! +//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?; +//! let layer = Linear::new(w, None); // Use no bias. +//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?; +//! let ys = layer.forward(&xs)?; +//! assert_eq!(ys.to_vec2::()?, &[[210.0, 430.0, 650.0]]); +//! # Ok(()) } +//! 
``` +use mcandle_core::{Result, Tensor}; + +#[derive(Clone, Debug)] +pub struct Linear { + weight: Tensor, + bias: Option, +} + +impl Linear { + pub fn new(weight: Tensor, bias: Option) -> Self { + Self { weight, bias } + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } +} + +impl super::Module for Linear { + fn forward(&self, x: &Tensor) -> mcandle_core::Result { + let w = match *x.dims() { + [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + _ => self.weight.t()?, + }; + let x = x.matmul(&w)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +/// Create or initialize a new linear layer. +/// +/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`. +pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + let bound = 1. / (in_dim as f64).sqrt(); + let init_bs = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vb.get_with_hints(out_dim, "bias", init_bs)?; + Ok(Linear::new(ws, Some(bs))) +} + +/// Create or initialize a new linear layer without biases. +pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result { + let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + Ok(Linear::new(ws, None)) +} + +pub fn linear_b( + in_dim: usize, + out_dim: usize, + bias: bool, + vb: crate::VarBuilder, +) -> Result { + if bias { + linear(in_dim, out_dim, vb) + } else { + linear_no_bias(in_dim, out_dim, vb) + } +} diff --git a/mcandle/mcandle-nn/src/loss.rs b/mcandle/mcandle-nn/src/loss.rs new file mode 100644 index 000000000..75fa456e0 --- /dev/null +++ b/mcandle/mcandle-nn/src/loss.rs @@ -0,0 +1,76 @@ +//! Loss Calculations +//! +use mcandle_core::{Result, Tensor}; + +/// The negative log likelihood loss. +/// +/// Arguments +/// +/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number +/// of categories. This is expected to contain log probabilities. +/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`. +/// +/// The resulting tensor is a scalar containing the average value over the batch. +pub fn nll(inp: &Tensor, target: &Tensor) -> Result { + let b_sz = match target.dims() { + &[b_sz] => b_sz, + dims => mcandle_core::bail!("the target tensor should have a single dimension ({dims:?})"), + }; + match inp.dims() { + &[inp_b_sz, _] => { + if inp_b_sz != b_sz { + mcandle_core::bail!( + "batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})" + ) + } + } + dims => mcandle_core::bail!("the target tensor should have two dimensions ({dims:?})"), + } + inp.gather(&target.unsqueeze(1)?, 1)? + .sum_all()? + .affine(-1f64 / b_sz as f64, 0.) +} + +/// The cross-entropy loss. +/// +/// Arguments +/// +/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number +/// of categories. This is expected to raw logits. +/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`. +/// +/// The resulting tensor is a scalar containing the average value over the batch. 
+pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result { + if inp.rank() != 2 { + mcandle_core::bail!("cross_entropy expects an input tensor of rank 2") + } + let inp = crate::ops::log_softmax(inp, 1)?; + nll(&inp, target) +} + +/// The mean squared error loss. +pub fn mse(inp: &Tensor, target: &Tensor) -> Result { + (inp - target)?.sqr()?.mean_all() +} + +/// The binary cross-entropy with logit loss. +/// +/// Arguments +/// +/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number +/// of categories. This is expected to raw logits. +/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number +/// of categories. +/// +/// The resulting tensor is a scalar containing the average value over the batch. +pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result { + let inp = crate::ops::sigmoid(inp)?; + + let left_side = target * inp.log()?; + let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?; + + let loss = left_side? + right_side?; + let loss = loss?.neg()?.mean_all()?; + + Ok(loss) +} diff --git a/mcandle/mcandle-nn/src/ops.rs b/mcandle/mcandle-nn/src/ops.rs new file mode 100644 index 000000000..9fbb96393 --- /dev/null +++ b/mcandle/mcandle-nn/src/ops.rs @@ -0,0 +1,1293 @@ +//! Tensor ops. +//! + +use mcandle_core::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D}; +use rayon::prelude::*; + +/// Applies the softmax function to the input tensor, rescaling the element so that elements on +/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. +/// +/// ```rust +/// use mcandle_core::{Tensor, Device, test_utils::to_vec2_round}; +/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?; +/// let a = mcandle_nn::ops::softmax(&a, 1)?; +/// assert_eq!( +/// to_vec2_round(&a, 4)?, +/// &[ +/// [0.1345, 0.3655, 0.1345, 0.3655], +/// [0.0049, 0.2671, 0.7262, 0.0018] +/// ]); +/// # Ok::<(), mcandle_core::Error>(()) +/// ``` +pub fn softmax(xs: &Tensor, dim: D) -> Result { + let dim = dim.to_index(xs.shape(), "softmax")?; + let max = xs.max_keepdim(dim)?; + let diff = xs.broadcast_sub(&max)?; + let num = diff.exp()?; + let den = num.sum_keepdim(dim)?; + num.broadcast_div(&den) +} + +pub fn log_softmax(xs: &Tensor, d: D) -> Result { + let d = d.to_index(xs.shape(), "log-softmax")?; + let max = xs.max_keepdim(d)?; + let diff = xs.broadcast_sub(&max)?; + let sum_exp = diff.exp()?.sum_keepdim(d)?; + let log_sm = diff.broadcast_sub(&sum_exp.log()?)?; + Ok(log_sm) +} + +pub fn silu(xs: &Tensor) -> Result { + xs.silu() +} + +pub fn swiglu(xs: &Tensor) -> Result { + let xs = xs.chunk(2, D::Minus1)?; + &xs[0].silu()? * &xs[1] +} + +struct Sigmoid; + +impl mcandle_core::CustomOp1 for Sigmoid { + fn name(&self) -> &'static str { + "sigmoid" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + + fn fwd(v: T) -> T { + (v.neg().exp() + T::one()).recip() + } + + // FIXME: using `mcandle_core::map_dtype` causes compilation errors. 
+ let storage = match storage { + CpuStorage::BF16(slice) => { + CpuStorage::BF16(mcandle_core::cpu_backend::unary_map(slice, layout, fwd)) + } + CpuStorage::F16(slice) => { + CpuStorage::F16(mcandle_core::cpu_backend::unary_map(slice, layout, fwd)) + } + CpuStorage::F32(slice) => { + CpuStorage::F32(mcandle_core::cpu_backend::unary_map(slice, layout, fwd)) + } + CpuStorage::F64(slice) => { + CpuStorage::F64(mcandle_core::cpu_backend::unary_map(slice, layout, fwd)) + } + _ => Err(mcandle_core::Error::UnsupportedDTypeForOp( + storage.dtype(), + self.name(), + ))?, + }; + Ok((storage, layout.shape().clone())) + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &mcandle_core::CudaStorage, + layout: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits, + }; + use mcandle_core::cuda_backend::SlicePtrOrNull; + use mcandle_core::cuda_backend::{kernel_name, kernels, Map1, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + struct S; + impl Map1 for S { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("usigmoid"), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }.w()?; + + let params = (el_count, dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } + } + + let dev = storage.device(); + let slice = S.map(&storage.slice, dev, layout)?; + let dst = mcandle_core::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, layout.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &mcandle_core::MetalStorage, + layout: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + use mcandle_core::MetalError; + let device = storage.device(); + let dtype = storage.dtype(); + let shape = layout.shape(); + let el_count = shape.elem_count(); + let buffer = device.new_buffer(el_count, dtype, "sigmoid")?; + let command_buffer = device.command_buffer()?; + command_buffer.set_label("sigmoid"); + let src = mcandle_metal_kernels::BufferOffset { + buffer: storage.buffer(), + offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(), + }; + + match (el_count % 2, dtype, layout.is_contiguous()) { + (0, DType::BF16 | DType::F16, true) => { + use mcandle_metal_kernels::unary::contiguous_tiled; + let kernel_name = match dtype { + DType::F16 => contiguous_tiled::sigmoid::HALF, + DType::F32 => contiguous_tiled::sigmoid::FLOAT, + DType::BF16 => contiguous_tiled::sigmoid::BFLOAT, + dtype => { + mcandle_core::bail!( + "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented" + ) + } + }; + mcandle_metal_kernels::call_unary_contiguous_tiled( + device.metal_device(), + &command_buffer, + device.kernels(), + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } + (_, _, true) => { + use mcandle_metal_kernels::unary::contiguous; + let kernel_name = match dtype { + DType::F16 => contiguous::sigmoid::HALF, + DType::F32 => contiguous::sigmoid::FLOAT, 
+ DType::BF16 => contiguous::sigmoid::BFLOAT, + dtype => { + mcandle_core::bail!( + "Metal contiguous unary sigmoid {dtype:?} not implemented" + ) + } + }; + mcandle_metal_kernels::call_unary_contiguous( + device.metal_device(), + &command_buffer, + device.kernels(), + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } + (_, _, false) => { + use mcandle_metal_kernels::unary::strided; + let kernel_name = match dtype { + DType::F16 => strided::sigmoid::HALF, + DType::F32 => strided::sigmoid::FLOAT, + DType::BF16 => strided::sigmoid::BFLOAT, + dtype => { + mcandle_core::bail!("Metal strided unary sigmoid {dtype:?} not implemented") + } + }; + let dst = mcandle_metal_kernels::BufferOffset::zero_offset(&buffer); + mcandle_metal_kernels::call_unary_strided( + device.metal_device(), + &command_buffer, + device.kernels(), + kernel_name, + layout.dims(), + src, + layout.stride(), + dst, + ) + .map_err(MetalError::from)?; + } + } + + let new_storage = mcandle_core::MetalStorage::new(buffer, device.clone(), el_count, dtype); + Ok((new_storage, layout.shape().clone())) + } + + fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result> { + // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x) + let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?; + Ok(Some(grad_res.mul(&d_dx_sigmoid)?)) + } +} + +pub fn sigmoid(xs: &Tensor) -> Result { + xs.apply_op1(Sigmoid) +} + +pub fn hard_sigmoid(xs: &Tensor) -> Result { + // TODO: Should we have a specialized op for this? + ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32) +} + +pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result { + let zeros = xs.zeros_like()?; + xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope +} + +pub fn dropout(xs: &Tensor, drop_p: f32) -> Result { + // This implementation is inefficient as it stores the full mask for the backward pass. + // Instead we could just store the seed and have a specialized kernel that would both + // generate the random mask and apply it. + // Another easier optimization would be to be able to generate boolean mask using just a bit of + // entropy per element rather than generating a full float per element. + if !(0. ..1.).contains(&drop_p) { + mcandle_core::bail!("dropout probability has to be in [0, 1), got {drop_p}") + } + let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?; + let scale = 1.0 / (1.0 - drop_p as f64); + let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?; + let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? 
* scale)?; + xs * mask +} + +#[derive(Clone, Debug)] +pub struct Dropout { + drop_p: f32, +} + +impl Dropout { + pub fn new(drop_p: f32) -> Dropout { + Self { drop_p } + } + + pub fn forward(&self, xs: &Tensor, train: bool) -> Result { + if train { + dropout(xs, self.drop_p) + } else { + Ok(xs.clone()) + } + } +} + +impl mcandle_core::ModuleT for Dropout { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { + self.forward(xs, train) + } +} + +struct SoftmaxLastDim; + +impl mcandle_core::CustomOp1 for SoftmaxLastDim { + fn name(&self) -> &'static str { + "softmax-last-dim" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + fn softmax( + src: &[T], + layout: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => mcandle_core::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let mut max = T::neg_infinity(); + unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) }; + for (s, d) in src.iter().zip(dst.iter_mut()) { + *d = (*s - max).exp(); + } + let mut sum_exp = T::zero(); + unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) }; + for d in dst.iter_mut() { + *d /= sum_exp + } + }); + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + match storage { + CpuStorage::BF16(slice) => softmax::(slice, layout), + CpuStorage::F16(slice) => softmax::(slice, layout), + CpuStorage::F32(slice) => softmax::(slice, layout), + CpuStorage::F64(slice) => softmax::(slice, layout), + _ => mcandle_core::bail!("unsupported dtype for softmax {:?}", storage), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &mcandle_core::CudaStorage, + layout: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use mcandle_core::cuda_backend::{kernel_name, kernels, Map1, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + struct S; + impl Map1 for S { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let src = match layout.contiguous_offsets() { + None => mcandle_core::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (1, 32, 1), + shared_mem_bytes: 0, + }; + let func = dev.get_or_load_func(&kernel_name::("softmax"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = (&src, &dst, n_cols as i32); + // SAFETY: ffi. 
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use mcandle_core::backend::BackendStorage; + let dev = storage.device(); + let slice = S.map(&storage.slice, dev, layout)?; + let dst = mcandle_core::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, layout.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + storage: &mcandle_core::MetalStorage, + layout: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = storage.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + let name = match storage.dtype() { + DType::F32 => "softmax_f32", + DType::F16 => "softmax_f16", + DType::BF16 => "softmax_bf16", + dtype => mcandle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"), + }; + + let n = layout.stride().len(); + if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) { + mcandle_core::bail!("Non contiguous softmax-last-dim is not implemented"); + } + + let last_dim = layout.dims()[layout.shape().rank() - 1]; + let elem_count = layout.shape().elem_count(); + let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; + mcandle_metal_kernels::call_last_softmax( + device.metal_device(), + &command_buffer, + kernels, + name, + elem_count, + last_dim, + storage.buffer(), + layout.start_offset() * storage.dtype().size_in_bytes(), + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let newstorage = + mcandle_core::MetalStorage::new(output, device.clone(), elem_count, storage.dtype()); + Ok((newstorage, layout.shape().clone())) + } +} + +pub fn softmax_last_dim(xs: &Tensor) -> Result { + xs.apply_op1_no_bwd(&SoftmaxLastDim) +} + +// TODO: need cpu and cuda impls +#[allow(dead_code)] +struct AttnSoftmaxLastDim { + scale: f32, +} + +impl mcandle_core::CustomOp2 for AttnSoftmaxLastDim { + fn name(&self) -> &'static str { + "attn-softmax-last-dim" + } + + fn cpu_fwd( + &self, + _a_s: &CpuStorage, + _a_l: &Layout, + _mask_s: &CpuStorage, + _mask_l: &Layout, + ) -> Result<(CpuStorage, Shape)> { + mcandle_core::bail!("cpu attn-softmax-last-dim is not implemented"); + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + a_s: &mcandle_core::MetalStorage, + a_l: &Layout, + mask_s: &mcandle_core::MetalStorage, + mask_l: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = a_s.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + + let ty = match a_s.dtype() { + DType::F32 => mcandle_metal_kernels::SdpaDType::F32, + DType::F16 => mcandle_metal_kernels::SdpaDType::F16, + DType::BF16 => mcandle_metal_kernels::SdpaDType::BF16, + dtype => mcandle_core::bail!("attn-softmax-last-dim is not implemented for {dtype:?}"), + }; + + if !a_l.is_contiguous() { + mcandle_core::bail!("Non contiguous xs for attn-softmax-last-dim is not implemented"); + } + if !mask_l.is_contiguous() { + mcandle_core::bail!("Non contiguous mask for attn-softmax-last-dim is not implemented"); + } + + if a_l.dims().len() != 4 { + mcandle_core::bail!("attn-softmax-last-dim expects xs of rank 2"); + } + if mask_l.dims().len() != 2 { + mcandle_core::bail!("attn-softmax-last-dim expects mask of rank 2"); + } + if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)? + || mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)? 
+ { + mcandle_core::bail!( + "attn-softmax-last-dim expects last 2 dims to match xs last 2 dims" + ); + } + + let elem_count = a_l.shape().elem_count(); + let output = device.new_buffer(elem_count, a_s.dtype(), "attn-softmax")?; + mcandle_metal_kernels::call_last_attn_softmax( + device.metal_device(), + &command_buffer, + kernels, + a_s.buffer(), + a_l.start_offset(), + mask_s.buffer(), + mask_l.start_offset(), + a_l.dims(), + self.scale, + ty, + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let newstorage = + mcandle_core::MetalStorage::new(output, device.clone(), elem_count, a_s.dtype()); + Ok((newstorage, a_l.shape().clone())) + } +} + +/// Softmax with fused broadcast addition of a mask and scale. +/// Equivalent to: +/// ```ignore +/// mcandle_nn::ops::softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)? +/// ``` +/// - `xs` must be a rank-4 tensor +/// - `mask` must be a rank-2 matrix +/// - The last 2 dimensions of `xs` must match the dimensions of `mask`. +/// +/// Note: if the last dim of `xs` is a multiple of 4, a vectorized implementation will be used. +pub fn attn_softmax_last_dim(xs: &Tensor, mask: &Tensor, scale: f32) -> Result { + if xs.device().is_metal() { + xs.apply_op2_no_bwd(mask, &AttnSoftmaxLastDim { scale }) + } else { + softmax_last_dim(&(xs.broadcast_add(mask)? * scale as f64)?) + } +} + +#[derive(Debug, Clone)] +struct RmsNorm { + eps: f32, +} + +impl mcandle_core::CustomOp2 for RmsNorm { + fn name(&self) -> &'static str { + "rms-norm" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + + let eps = self.eps; + fn inner< + T: mcandle_core::WithDType + + num_traits::Float + + num_traits::AsPrimitive + + num_traits::FromPrimitive, + >( + src: &[T], + layout: &Layout, + alpha: &[T], + alpha_layout: &Layout, + eps: f32, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => mcandle_core::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => mcandle_core::bail!("alpha has to be contiguous"), + Some((o1, o2)) => &alpha[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let sum2 = src + .iter() + .map(|&v| { + let v = v.as_(); + v * v + }) + .sum::(); + let m = (sum2 / dim_m1 as f32 + eps).sqrt(); + let m = T::from_f32(m).unwrap_or_else(T::nan); + for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) { + *d = *s / m * *alpha + } + }); + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + use CpuStorage as C; + match (s1, s2) { + (C::BF16(s1), C::BF16(s2)) => inner::(s1, l1, s2, l2, eps), + (C::F16(s1), C::F16(s2)) => inner::(s1, l1, s2, l2, eps), + (C::F32(s1), C::F32(s2)) => inner::(s1, l1, s2, l2, eps), + _ => mcandle_core::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &mcandle_core::CudaStorage, + l1: &Layout, + s2: &mcandle_core::CudaStorage, + l2: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use 
mcandle_core::cuda_backend::{kernel_name, kernels, Map2, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + struct S { + eps: f32, + } + impl Map2 for S { + fn f( + &self, + src: &CudaSlice, + layout: &Layout, + alpha: &CudaSlice, + alpha_layout: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match layout.contiguous_offsets() { + None => mcandle_core::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => mcandle_core::bail!("alpha has to be contiguous"), + Some((o1, o2)) => alpha.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let block_size = if n_cols < 1024 { 32 } else { 1024 }; + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (block_size, 1, 1), + shared_mem_bytes: 0, + }; + let func = dev.get_or_load_func(&kernel_name::("rmsnorm"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, + &dst, + &alpha, + n_cols as i32, + block_size as i32, + self.eps, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use mcandle_core::backend::BackendStorage; + let dev = s1.device(); + let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?; + let dst = mcandle_core::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + s1: &mcandle_core::MetalStorage, + l1: &Layout, + s2: &mcandle_core::MetalStorage, + l2: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = s1.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + let name = match (s1.dtype(), s2.dtype()) { + (DType::F32, DType::F32) => "rmsnorm_f32", + (DType::F16, DType::F16) => "rmsnorm_f16", + (DType::BF16, DType::BF16) => "rmsnorm_bf16", + (dt1, dt2) => mcandle_core::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"), + }; + + if !(l1.is_contiguous() && l2.is_contiguous()) { + mcandle_core::bail!("Non contiguous rmsnorm is not implemented"); + } + + let last_dim = l1.dims()[l1.shape().rank() - 1]; + let elem_count = l1.shape().elem_count(); + let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?; + mcandle_metal_kernels::call_rms_norm( + device.metal_device(), + &command_buffer, + kernels, + name, + elem_count, + last_dim, + self.eps, + s1.buffer(), + l1.start_offset() * s1.dtype().size_in_bytes(), + s2.buffer(), + l2.start_offset() * s2.dtype().size_in_bytes(), + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let newstorage = + mcandle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype()); + Ok((newstorage, l1.shape().clone())) + } +} + +pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? 
/ hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; + x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha) +} + +pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result { + let hidden_size_xs = xs.dim(D::Minus1)?; + let hidden_size_alpha = alpha.dims1()?; + if hidden_size_xs != hidden_size_alpha { + mcandle_core::bail!( + "shape mismatch in rms-norm {:?} {:?}", + xs.shape(), + alpha.shape() + ) + } + xs.apply_op2_no_bwd(alpha, &RmsNorm { eps }) +} + +#[derive(Debug, Clone)] +struct LayerNorm { + eps: f32, +} + +impl mcandle_core::CustomOp3 for LayerNorm { + fn name(&self) -> &'static str { + "layer-norm" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + + let eps = self.eps; + fn inner< + T: mcandle_core::WithDType + + num_traits::Float + + num_traits::AsPrimitive + + num_traits::FromPrimitive, + >( + src: &[T], + layout: &Layout, + alpha: &[T], + alpha_layout: &Layout, + beta: &[T], + beta_layout: &Layout, + eps: f32, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => mcandle_core::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => mcandle_core::bail!("alpha has to be contiguous"), + Some((o1, o2)) => &alpha[o1..o2], + }; + let beta = match beta_layout.contiguous_offsets() { + None => mcandle_core::bail!("beta has to be contiguous"), + Some((o1, o2)) => &beta[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let mut sum = 0f32; + let mut sum2 = 0f32; + for v in src { + let v = v.as_(); + sum += v; + sum2 += v * v; + } + let mean = sum / dim_m1 as f32; + let var = sum2 / dim_m1 as f32 - mean * mean; + let inv_std = (var + eps).sqrt().recip(); + for ((d, s), (alpha, beta)) in + dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta)) + { + let alpha = alpha.as_(); + let beta = beta.as_(); + let d_ = (s.as_() - mean) * inv_std * alpha + beta; + *d = T::from_f32(d_).unwrap_or_else(T::nan); + } + }); + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + use CpuStorage as C; + match (s1, s2, s3) { + (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => { + inner::(s1, l1, s2, l2, s3, l3, eps) + } + (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::(s1, l1, s2, l2, s3, l3, eps), + (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::(s1, l1, s2, l2, s3, l3, eps), + _ => mcandle_core::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &mcandle_core::CudaStorage, + l1: &Layout, + s2: &mcandle_core::CudaStorage, + l2: &Layout, + s3: &mcandle_core::CudaStorage, + l3: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use mcandle_core::cuda_backend::{kernel_name, kernels, Map3, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + struct S { + eps: f32, + } + impl Map3 for S { + fn f( + &self, + src: &CudaSlice, + layout: &Layout, + alpha: &CudaSlice, + alpha_layout: &Layout, + beta: &CudaSlice, + beta_layout: 
&Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match layout.contiguous_offsets() { + None => mcandle_core::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => mcandle_core::bail!("alpha has to be contiguous"), + Some((o1, o2)) => alpha.slice(o1..o2), + }; + let beta = match beta_layout.contiguous_offsets() { + None => mcandle_core::bail!("beta has to be contiguous"), + Some((o1, o2)) => beta.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let block_size = if n_cols < 1024 { 32 } else { 1024 }; + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (block_size, 1, 1), + shared_mem_bytes: 0, + }; + let func = dev.get_or_load_func(&kernel_name::("layernorm"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, + &dst, + &alpha, + &beta, + n_cols as i32, + block_size as i32, + self.eps, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use mcandle_core::backend::BackendStorage; + let dev = s1.device(); + let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?; + let dst = mcandle_core::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + s1: &mcandle_core::MetalStorage, + l1: &Layout, + s2: &mcandle_core::MetalStorage, + l2: &Layout, + s3: &mcandle_core::MetalStorage, + l3: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = s1.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + let name = match (s1.dtype(), s2.dtype(), s3.dtype()) { + (DType::F32, DType::F32, DType::F32) => "layernorm_f32", + (DType::F16, DType::F16, DType::F16) => "layernorm_f16", + (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16", + (dt1, dt2, dt3) => { + mcandle_core::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}") + } + }; + + if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) { + mcandle_core::bail!("Non contiguous layernorm is not implemented"); + } + + let last_dim = l1.dims()[l1.shape().rank() - 1]; + let elem_count = l1.shape().elem_count(); + let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?; + mcandle_metal_kernels::call_layer_norm( + device.metal_device(), + &command_buffer, + kernels, + name, + elem_count, + last_dim, + self.eps, + s1.buffer(), + l1.start_offset() * s1.dtype().size_in_bytes(), + s2.buffer(), + l2.start_offset() * s2.dtype().size_in_bytes(), + s3.buffer(), + l3.start_offset() * s3.dtype().size_in_bytes(), + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let newstorage = + mcandle_core::MetalStorage::new(output, device.clone(), elem_count, s1.dtype()); + Ok((newstorage, l1.shape().clone())) + } +} + +pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let x = { + let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + x.broadcast_sub(&mean_x)? 
+ }; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(alpha)? + .broadcast_add(beta) +} + +pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result { + let hidden_size_xs = xs.dim(D::Minus1)?; + let hidden_size_alpha = alpha.dims1()?; + let hidden_size_beta = beta.dims1()?; + if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta { + mcandle_core::bail!( + "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}", + xs.shape(), + alpha.shape(), + beta.shape() + ) + } + xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps }) +} + +// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html +pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result { + let (b_size, c, h, w) = xs.dims4()?; + let out_c = c / upscale_factor / upscale_factor; + xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))? + .permute((0, 1, 4, 2, 5, 3))? + .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor)) +} + +pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result { + let (b_size, c, h, w) = xs.dims4()?; + let out_c = c * downscale_factor * downscale_factor; + xs.reshape(( + b_size, + c, + h / downscale_factor, + downscale_factor, + w / downscale_factor, + downscale_factor, + ))? + .permute((0, 1, 3, 5, 2, 4))? + .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor)) +} + +// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html +pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result { + match pad { + 0 => Ok(xs.clone()), + 1 => { + let (_b_size, _c, h, w) = xs.dims4()?; + let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?); + let xs = Tensor::cat(&[&first, xs, &last], 3)?; + let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?); + Tensor::cat(&[&first, &xs, &last], 2) + } + n => mcandle_core::bail!("replication-pad with a size of {n} is not supported"), + } +} + +#[cfg(feature = "cuda")] +pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: usize) -> Result { + if !ltensor.device().is_cuda() { + return Tensor::cat(&[ltensor, &rtensor], concat_dim as usize)?.contiguous(); + } + use mcandle_core::cuda_backend::KVConcat; + let op = KVConcat { concat_dim }; + //inputs for kvconcat must be contiguous tensors + if ltensor.is_contiguous() && rtensor.is_contiguous() { + ltensor.apply_op2(&rtensor, op) + } else if ltensor.is_contiguous() { + ltensor.apply_op2(&rtensor.contiguous()?, op) + } else if rtensor.is_contiguous() { + let ltensor = ltensor.contiguous()?; + ltensor.apply_op2(&rtensor, op) + } else { + let ltensor = ltensor.contiguous()?; + let rtensor = rtensor.contiguous()?; + ltensor.apply_op2(&rtensor, op) + } +} + +#[cfg(not(feature = "cuda"))] +pub fn kvconcat(ltensor: &Tensor, rtensor: &Tensor, concat_dim: i32) -> Result { + Tensor::cat(&[ltensor, rtensor], concat_dim as usize)?.contiguous() +} + +#[derive(Clone, Debug)] +pub struct Identity; + +impl Identity { + pub fn new() -> Identity { + Self + } +} + +impl Default for Identity { + fn default() -> Self { + Self + } +} + +impl Module for Identity { + fn forward(&self, xs: &Tensor) -> Result { + Ok(xs.clone()) + } +} + +#[allow(dead_code)] +struct Sdpa { + scale: f32, + softcapping: f32, +} + +impl mcandle_core::CustomOp3 for Sdpa { + fn name(&self) -> &'static str { + "metal-sdpa" + } + + fn cpu_fwd( + &self, + _s1: &CpuStorage, + 
_l1: &Layout, + _s2: &CpuStorage, + _l2: &Layout, + _s3: &CpuStorage, + _l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + mcandle_core::bail!("SDPA has no cpu impl") + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + q: &mcandle_core::MetalStorage, + q_l: &Layout, + k: &mcandle_core::MetalStorage, + k_l: &Layout, + v: &mcandle_core::MetalStorage, + v_l: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + use mcandle_metal_kernels::SdpaDType; + + let device = q.device(); + + let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?]; + let elem_count: usize = out_dims.iter().product(); + + let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; + + // q,k must have matching emb dim + if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? { + mcandle_core::bail!("`q` and `k` last dims must match"); + } + + // k,v must have matching n kv heads + if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? { + mcandle_core::bail!("`k` and `v` head dims must match"); + } + + // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1. + if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 { + mcandle_core::bail!("query `n_heads` must be a multiple of `n_kv_heads`"); + } + + let k_head = k_l.dim(D::Minus1)?; + let q_head = q_l.dim(D::Minus1)?; + let q_seq = q_l.dim(2)?; + + let mut implementation_supports_use_case = q_head == k_head; + let supported_head_dim = + q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256; + + const SDPA_FULL_THRESHOLD: usize = 2; + + let supports_sdpa_full = + q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head; + let supports_sdpa_vector = q_seq == 1 && supported_head_dim; + + implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector; + + if !supported_head_dim { + mcandle_core::bail!( + "Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.", + q_l.dims(), + k_l.dims(), + v_l.dims() + ); + } + if !implementation_supports_use_case { + mcandle_core::bail!( + "Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.", + q_l.dims(), + k_l.dims(), + v_l.dims() + ); + } + + for t in [k.dtype(), v.dtype()] { + if q.dtype() != t { + mcandle_core::bail!("all q, k, v dtypes must match."); + } + } + + let itype = match q.dtype() { + DType::BF16 => SdpaDType::BF16, + DType::F16 => SdpaDType::F16, + DType::F32 => SdpaDType::F32, + other => mcandle_core::bail!("unsupported sdpa type {other:?}"), + }; + + let command_buffer = q.device().command_buffer()?; + if supports_sdpa_vector { + command_buffer.set_label("vector_attention"); + mcandle_metal_kernels::call_sdpa_vector( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k_l.dims(), + k_l.stride(), + k.buffer(), + v_l.start_offset(), + v_l.stride(), + v.buffer(), + &output, + self.scale, + self.softcapping, + itype, + ) + .map_err(mcandle_core::Error::wrap)?; + } else if supports_sdpa_full { + if q_l.dim(2)? != k_l.dim(2)? 
{ + mcandle_core::bail!( + "query and key sequence length must be equal if using full metal sdpa" + ) + } + + command_buffer.set_label("full_attention"); + mcandle_metal_kernels::call_sdpa_full( + q.device().device(), + &command_buffer, + q.device().kernels(), + q_l.start_offset(), + q_l.dims(), + q.buffer(), + k_l.start_offset(), + k.buffer(), + v_l.start_offset(), + v.buffer(), + &output, + self.scale, + self.softcapping, + itype, + ) + .map_err(mcandle_core::Error::wrap)?; + } else { + mcandle_core::bail!("must be vector or full sdpa kernel"); + } + + let newstorage = + mcandle_core::MetalStorage::new(output, device.clone(), elem_count, q.dtype()); + Ok((newstorage, Shape::from_dims(&out_dims))) + } +} + +/// Scaled dot product attention with a fused kernel. +/// +/// Computes softmax(qk^T*scale)v. +/// +/// **Inputs shapes:** +/// - `q`: (bs, qhead, seq, hidden) +/// - `k`: (bs, kv_head, kv_seq, hidden) +/// - `k`: (bs, kv_head, kv_seq, v_hidden) +/// - `scale` is applied before softmax. +/// - If `softcapping` != 1.0: +/// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v +/// +/// **Output shape:** (bs, qhead, seq, v_hidden) +/// +/// **Supported head dims:** 32, 64, 96, 128, 256. +/// +/// ## On Metal: +/// - If `seq` == 1: +/// - Use a vectorized kernel +/// - Supports `seq` != `kv_seq` (cross attn. support) +/// - Supports GQA when `qhead` is a multiple of `kv_head` +/// - Otherwise: +/// - Use an alternate kernel +/// - Requires `seq` == `kv_seq` +/// - GQA is not supported (requires `qhead` == `kv_head`) +pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result { + q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping }) +} diff --git a/mcandle/mcandle-nn/src/optim.rs b/mcandle/mcandle-nn/src/optim.rs new file mode 100644 index 000000000..1e0378a87 --- /dev/null +++ b/mcandle/mcandle-nn/src/optim.rs @@ -0,0 +1,201 @@ +//! Various optimization algorithms. +use mcandle_core::{Result, Tensor, Var}; + +/// The interface optimizers should implement. +pub trait Optimizer: Sized { + type Config: Sized; + + fn new(vars: Vec, config: Self::Config) -> Result; + + fn step(&mut self, grads: &mcandle_core::backprop::GradStore) -> Result<()>; + + fn learning_rate(&self) -> f64; + + fn set_learning_rate(&mut self, lr: f64); + + fn empty(config: Self::Config) -> Result { + Self::new(vec![], config) + } + + fn backward_step(&mut self, loss: &Tensor) -> Result<()> { + let grads = loss.backward()?; + self.step(&grads) + } + + fn from_slice(vars: &[&Var], config: Self::Config) -> Result { + let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect(); + Self::new(vars, config) + } +} + +/// Optimizer for Stochastic Gradient Descent. +/// +/// Contrary to the PyTorch implementation of SGD, this version does not support momentum. 
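A CPU-only sketch exercising the softmax helpers from `ops.rs` above; the fused mask/scale kernel and the metal-sdpa op require the `metal` feature, so this only checks the portable fallbacks on illustrative shapes:

```rust
use mcandle_core::{test_utils::to_vec3_round, DType, Device, Result, Tensor, D};
use mcandle_nn::ops;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::rand(0f32, 1f32, (2, 3, 4), &dev)?;

    // The fused last-dim kernel and the generic softmax over D::Minus1 should agree.
    let fused = ops::softmax_last_dim(&xs)?;
    let reference = ops::softmax(&xs, D::Minus1)?;
    assert_eq!(to_vec3_round(&fused, 4)?, to_vec3_round(&reference, 4)?);

    // On non-Metal devices attn_softmax_last_dim takes the broadcast-add + scale fallback;
    // xs must be rank 4 and the mask rank 2, matching the last two dims of xs.
    let att = Tensor::rand(0f32, 1f32, (1, 2, 3, 4), &dev)?;
    let mask = Tensor::zeros((3, 4), DType::F32, &dev)?;
    let ys = ops::attn_softmax_last_dim(&att, &mask, 0.5)?;
    assert_eq!(ys.dims(), &[1, 2, 3, 4]);
    Ok(())
}
```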
+#[derive(Debug)] +pub struct SGD { + vars: Vec, + learning_rate: f64, +} + +impl Optimizer for SGD { + type Config = f64; + + fn new(vars: Vec, learning_rate: f64) -> Result { + let vars = vars + .into_iter() + .filter(|var| var.dtype().is_float()) + .collect(); + Ok(Self { + vars, + learning_rate, + }) + } + + fn learning_rate(&self) -> f64 { + self.learning_rate + } + + fn step(&mut self, grads: &mcandle_core::backprop::GradStore) -> Result<()> { + for var in self.vars.iter() { + if let Some(grad) = grads.get(var) { + var.set(&var.sub(&(grad * self.learning_rate)?)?)?; + } + } + Ok(()) + } + + fn set_learning_rate(&mut self, lr: f64) { + self.learning_rate = lr + } +} + +impl SGD { + pub fn into_inner(self) -> Vec { + self.vars + } + + pub fn push(&mut self, var: &Var) { + self.vars.push(var.clone()) + } +} + +#[derive(Clone, Debug)] +pub struct ParamsAdamW { + pub lr: f64, + pub beta1: f64, + pub beta2: f64, + pub eps: f64, + pub weight_decay: f64, +} + +impl Default for ParamsAdamW { + fn default() -> Self { + Self { + lr: 0.001, + beta1: 0.9, + beta2: 0.999, + eps: 1e-8, + weight_decay: 0.01, + } + } +} + +#[derive(Debug)] +struct VarAdamW { + var: Var, + first_moment: Var, + second_moment: Var, +} + +#[derive(Debug)] +pub struct AdamW { + vars: Vec, + step_t: usize, + params: ParamsAdamW, +} + +impl Optimizer for AdamW { + type Config = ParamsAdamW; + + fn new(vars: Vec, params: ParamsAdamW) -> Result { + let vars = vars + .into_iter() + .filter(|var| var.dtype().is_float()) + .map(|var| { + let dtype = var.dtype(); + let shape = var.shape(); + let device = var.device(); + let first_moment = Var::zeros(shape, dtype, device)?; + let second_moment = Var::zeros(shape, dtype, device)?; + Ok(VarAdamW { + var, + first_moment, + second_moment, + }) + }) + .collect::>>()?; + Ok(Self { + vars, + params, + step_t: 0, + }) + } + + fn learning_rate(&self) -> f64 { + self.params.lr + } + + fn set_learning_rate(&mut self, lr: f64) { + self.params.lr = lr + } + + fn step(&mut self, grads: &mcandle_core::backprop::GradStore) -> Result<()> { + self.step_t += 1; + let lr = self.params.lr; + let lambda = self.params.weight_decay; + let lr_lambda = lr * lambda; + let beta1 = self.params.beta1; + let beta2 = self.params.beta2; + let scale_m = 1f64 / (1f64 - beta1.powi(self.step_t as i32)); + let scale_v = 1f64 / (1f64 - beta2.powi(self.step_t as i32)); + for var in self.vars.iter() { + let theta = &var.var; + let m = &var.first_moment; + let v = &var.second_moment; + if let Some(g) = grads.get(theta) { + // This involves locking 3 RWLocks per params, if the parameters are large this + // should not be an issue but this may be problematic with models with lots of + // small parameters. + let next_m = ((m.as_tensor() * beta1)? + (g * (1.0 - beta1))?)?; + let next_v = ((v.as_tensor() * beta2)? + (g.sqr()? * (1.0 - beta2))?)?; + let m_hat = (&next_m * scale_m)?; + let v_hat = (&next_v * scale_v)?; + let next_theta = (theta.as_tensor() * (1f64 - lr_lambda))?; + let adjusted_grad = (m_hat / (v_hat.sqrt()? 
+ self.params.eps)?)?; + let next_theta = (next_theta - (adjusted_grad * lr)?)?; + m.set(&next_m)?; + v.set(&next_v)?; + theta.set(&next_theta)?; + } + } + Ok(()) + } +} + +impl AdamW { + pub fn new_lr(vars: Vec, learning_rate: f64) -> Result { + let params = ParamsAdamW { + lr: learning_rate, + ..ParamsAdamW::default() + }; + Self::new(vars, params) + } + + pub fn params(&self) -> &ParamsAdamW { + &self.params + } + + pub fn set_params(&mut self, params: ParamsAdamW) { + self.params = params; + } +} diff --git a/mcandle/mcandle-nn/src/rnn.rs b/mcandle/mcandle-nn/src/rnn.rs new file mode 100644 index 000000000..14b1979a4 --- /dev/null +++ b/mcandle/mcandle-nn/src/rnn.rs @@ -0,0 +1,388 @@ +//! Recurrent Neural Networks +use mcandle_core::{DType, Device, IndexOp, Result, Tensor}; + +/// Trait for Recurrent Neural Networks. +#[allow(clippy::upper_case_acronyms)] +pub trait RNN { + type State: Clone; + + /// A zero state from which the recurrent network is usually initialized. + fn zero_state(&self, batch_dim: usize) -> Result; + + /// Applies a single step of the recurrent network. + /// + /// The input should have dimensions [batch_size, features]. + fn step(&self, input: &Tensor, state: &Self::State) -> Result; + + /// Applies multiple steps of the recurrent network. + /// + /// The input should have dimensions [batch_size, seq_len, features]. + /// The initial state is the result of applying zero_state. + fn seq(&self, input: &Tensor) -> Result> { + let batch_dim = input.dim(0)?; + let state = self.zero_state(batch_dim)?; + self.seq_init(input, &state) + } + + /// Applies multiple steps of the recurrent network. + /// + /// The input should have dimensions [batch_size, seq_len, features]. + fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result> { + let (_b_size, seq_len, _features) = input.dims3()?; + let mut output = Vec::with_capacity(seq_len); + for seq_index in 0..seq_len { + let input = input.i((.., seq_index, ..))?.contiguous()?; + let state = if seq_index == 0 { + self.step(&input, init_state)? + } else { + self.step(&input, &output[seq_index - 1])? + }; + output.push(state); + } + Ok(output) + } + + /// Converts a sequence of state to a tensor. + fn states_to_tensor(&self, states: &[Self::State]) -> Result; +} + +/// The state for a LSTM network, this contains two tensors. +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +pub struct LSTMState { + pub h: Tensor, + pub c: Tensor, +} + +impl LSTMState { + pub fn new(h: Tensor, c: Tensor) -> Self { + LSTMState { h, c } + } + + /// The hidden state vector, which is also the output of the LSTM. + pub fn h(&self) -> &Tensor { + &self.h + } + + /// The cell state vector. 
+ pub fn c(&self) -> &Tensor { + &self.c + } +} + +#[derive(Debug, Clone, Copy)] +pub enum Direction { + Forward, + Backward, +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone, Copy)] +pub struct LSTMConfig { + pub w_ih_init: super::Init, + pub w_hh_init: super::Init, + pub b_ih_init: Option, + pub b_hh_init: Option, + pub layer_idx: usize, + pub direction: Direction, +} + +impl Default for LSTMConfig { + fn default() -> Self { + Self { + w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM, + w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, + b_ih_init: Some(super::Init::Const(0.)), + b_hh_init: Some(super::Init::Const(0.)), + layer_idx: 0, + direction: Direction::Forward, + } + } +} + +impl LSTMConfig { + pub fn default_no_bias() -> Self { + Self { + w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM, + w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, + b_ih_init: None, + b_hh_init: None, + layer_idx: 0, + direction: Direction::Forward, + } + } +} + +/// A Long Short-Term Memory (LSTM) layer. +/// +/// +#[allow(clippy::upper_case_acronyms)] +#[derive(Clone, Debug)] +pub struct LSTM { + w_ih: Tensor, + w_hh: Tensor, + b_ih: Option, + b_hh: Option, + hidden_dim: usize, + config: LSTMConfig, + device: Device, + dtype: DType, +} + +impl LSTM { + /// Creates a LSTM layer. + pub fn new( + in_dim: usize, + hidden_dim: usize, + config: LSTMConfig, + vb: crate::VarBuilder, + ) -> Result { + let layer_idx = config.layer_idx; + let direction_str = match config.direction { + Direction::Forward => "", + Direction::Backward => "_reverse", + }; + let w_ih = vb.get_with_hints( + (4 * hidden_dim, in_dim), + &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported. + config.w_ih_init, + )?; + let w_hh = vb.get_with_hints( + (4 * hidden_dim, hidden_dim), + &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported. + config.w_hh_init, + )?; + let b_ih = match config.b_ih_init { + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_ih_l{layer_idx}{direction_str}"), + init, + )?), + None => None, + }; + let b_hh = match config.b_hh_init { + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_hh_l{layer_idx}{direction_str}"), + init, + )?), + None => None, + }; + Ok(Self { + w_ih, + w_hh, + b_ih, + b_hh, + hidden_dim, + config, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn config(&self) -> &LSTMConfig { + &self.config + } +} + +/// Creates a LSTM layer. 
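// --- Editor's note: illustrative sketch, not part of the patch above. ---
// How the `RNN` trait and the `LSTM` layer above compose with the `lstm` helper
// defined immediately below: build the layer from a `VarBuilder`, feed a
// (batch, seq_len, features) tensor through `seq`, and stack the per-step hidden
// states. The crate-root re-exports and the randomly initialised `VarMap` weights
// are assumptions based on upstream candle-nn; the shapes are arbitrary.
fn lstm_usage_sketch() -> mcandle_core::Result<()> {
    use mcandle_core::{DType, Device, Tensor};
    use mcandle_nn::{VarBuilder, VarMap, RNN};

    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // in_dim = 8, hidden_dim = 16, single forward layer (the default config).
    let lstm = mcandle_nn::lstm(8, 16, Default::default(), vb)?;
    let xs = Tensor::randn(0f32, 1f32, (2, 5, 8), &dev)?; // (batch, seq_len, features)
    let states = lstm.seq(&xs)?; // one LSTMState per time step
    let hs = lstm.states_to_tensor(&states)?; // hidden states stacked on dim 1
    assert_eq!(hs.dims(), &[2, 5, 16]);
    Ok(())
}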
+pub fn lstm( + in_dim: usize, + hidden_dim: usize, + config: LSTMConfig, + vb: crate::VarBuilder, +) -> Result { + LSTM::new(in_dim, hidden_dim, config, vb) +} + +impl RNN for LSTM { + type State = LSTMState; + + fn zero_state(&self, batch_dim: usize) -> Result { + let zeros = + Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?; + Ok(Self::State { + h: zeros.clone(), + c: zeros.clone(), + }) + } + + fn step(&self, input: &Tensor, in_state: &Self::State) -> Result { + let w_ih = input.matmul(&self.w_ih.t()?)?; + let w_hh = in_state.h.matmul(&self.w_hh.t()?)?; + let w_ih = match &self.b_ih { + None => w_ih, + Some(b_ih) => w_ih.broadcast_add(b_ih)?, + }; + let w_hh = match &self.b_hh { + None => w_hh, + Some(b_hh) => w_hh.broadcast_add(b_hh)?, + }; + let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?; + let in_gate = crate::ops::sigmoid(&chunks[0])?; + let forget_gate = crate::ops::sigmoid(&chunks[1])?; + let cell_gate = chunks[2].tanh()?; + let out_gate = crate::ops::sigmoid(&chunks[3])?; + + let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?; + let next_h = (out_gate * next_c.tanh()?)?; + Ok(LSTMState { + c: next_c, + h: next_h, + }) + } + + fn states_to_tensor(&self, states: &[Self::State]) -> Result { + let states = states.iter().map(|s| s.h.clone()).collect::>(); + Tensor::stack(&states, 1) + } +} + +/// The state for a GRU network, this contains a single tensor. +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +pub struct GRUState { + pub h: Tensor, +} + +impl GRUState { + /// The hidden state vector, which is also the output of the LSTM. + pub fn h(&self) -> &Tensor { + &self.h + } +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone, Copy)] +pub struct GRUConfig { + pub w_ih_init: super::Init, + pub w_hh_init: super::Init, + pub b_ih_init: Option, + pub b_hh_init: Option, +} + +impl Default for GRUConfig { + fn default() -> Self { + Self { + w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM, + w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, + b_ih_init: Some(super::Init::Const(0.)), + b_hh_init: Some(super::Init::Const(0.)), + } + } +} + +impl GRUConfig { + pub fn default_no_bias() -> Self { + Self { + w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM, + w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, + b_ih_init: None, + b_hh_init: None, + } + } +} + +/// A Gated Recurrent Unit (GRU) layer. +/// +/// +#[allow(clippy::upper_case_acronyms)] +#[derive(Clone, Debug)] +pub struct GRU { + w_ih: Tensor, + w_hh: Tensor, + b_ih: Option, + b_hh: Option, + hidden_dim: usize, + config: GRUConfig, + device: Device, + dtype: DType, +} + +impl GRU { + /// Creates a GRU layer. + pub fn new( + in_dim: usize, + hidden_dim: usize, + config: GRUConfig, + vb: crate::VarBuilder, + ) -> Result { + let w_ih = vb.get_with_hints( + (3 * hidden_dim, in_dim), + "weight_ih_l0", // Only a single layer is supported. + config.w_ih_init, + )?; + let w_hh = vb.get_with_hints( + (3 * hidden_dim, hidden_dim), + "weight_hh_l0", // Only a single layer is supported. 
+ config.w_hh_init, + )?; + let b_ih = match config.b_ih_init { + Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?), + None => None, + }; + let b_hh = match config.b_hh_init { + Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?), + None => None, + }; + Ok(Self { + w_ih, + w_hh, + b_ih, + b_hh, + hidden_dim, + config, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn config(&self) -> &GRUConfig { + &self.config + } +} + +pub fn gru( + in_dim: usize, + hidden_dim: usize, + config: GRUConfig, + vb: crate::VarBuilder, +) -> Result { + GRU::new(in_dim, hidden_dim, config, vb) +} + +impl RNN for GRU { + type State = GRUState; + + fn zero_state(&self, batch_dim: usize) -> Result { + let h = + Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?; + Ok(Self::State { h }) + } + + fn step(&self, input: &Tensor, in_state: &Self::State) -> Result { + let w_ih = input.matmul(&self.w_ih.t()?)?; + let w_hh = in_state.h.matmul(&self.w_hh.t()?)?; + let w_ih = match &self.b_ih { + None => w_ih, + Some(b_ih) => w_ih.broadcast_add(b_ih)?, + }; + let w_hh = match &self.b_hh { + None => w_hh, + Some(b_hh) => w_hh.broadcast_add(b_hh)?, + }; + let chunks_ih = w_ih.chunk(3, 1)?; + let chunks_hh = w_hh.chunk(3, 1)?; + let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?; + let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?; + let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh(); + + let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?; + Ok(GRUState { h: next_h }) + } + + fn states_to_tensor(&self, states: &[Self::State]) -> Result { + let states = states.iter().map(|s| s.h.clone()).collect::>(); + Tensor::cat(&states, 1) + } +} diff --git a/mcandle/mcandle-nn/src/rope.rs b/mcandle/mcandle-nn/src/rope.rs new file mode 100644 index 000000000..e8ff4d49b --- /dev/null +++ b/mcandle/mcandle-nn/src/rope.rs @@ -0,0 +1,328 @@ +use std::iter::zip; + +#[allow(unused_imports)] +use mcandle_core::{ + backend::BackendStorage, CudaDevice, CudaStorage, DType, Device, IndexOp, Module, Result, + Storage, Tensor, WithDType, D, +}; + +#[cfg(feature = "cuda")] +use mcandle_core::cuda_backend::{ + cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}, + kernel_name, kernels, CudaDType, +}; + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct RotaryEmbedding { + cos: Tensor, + sin: Tensor, + head_size: usize, + cache: Tensor, + is_gpt_neox: bool, +} + +impl RotaryEmbedding { + pub fn new( + base: f32, + head_dim: usize, + max_position_embeddings: usize, + device: &Device, + is_gpt_neox: bool, + dtype: DType, + ) -> Result { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta_len = theta.len(); + let theta = Tensor::from_vec(theta, (1, theta_len), device)?.to_dtype(DType::F32)?; + let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((max_position_embeddings, 1))? + .matmul(&theta)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + head_size: head_dim, + cos: if is_gpt_neox { + Tensor::cat( + &[cos.clone().to_dtype(dtype)?, cos.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + cos.clone().to_dtype(dtype)? + }, + sin: if is_gpt_neox { + Tensor::cat( + &[sin.clone().to_dtype(dtype)?, sin.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + sin.clone().to_dtype(dtype)? 
+ }, + cache: Tensor::cat(&[cos.clone(), sin.clone()], D::Minus1)? + .contiguous()? + .to_dtype(dtype)?, + is_gpt_neox, + }) + } + + pub fn new_partial( + base: f32, + head_dim: usize, + rot_dim: usize, + max_position_embeddings: usize, + device: &Device, + is_gpt_neox: bool, + dtype: DType, + ) -> Result { + let theta: Vec<_> = (0..rot_dim) + .step_by(2) + .map(|i| 1f32 / base.powf(i as f32 / rot_dim as f32)) + .collect(); + let theta_len = theta.len(); + let theta = Tensor::from_vec(theta, (1, theta_len), device)?.to_dtype(DType::F32)?; + let idx_theta = Tensor::arange(0, max_position_embeddings as u32, device)? + .to_dtype(DType::F32)? + .reshape((max_position_embeddings, 1))? + .matmul(&theta)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { + head_size: head_dim, + cos: if is_gpt_neox { + Tensor::cat( + &[cos.clone().to_dtype(dtype)?, cos.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + cos.clone().to_dtype(dtype)? + }, + sin: if is_gpt_neox { + Tensor::cat( + &[sin.clone().to_dtype(dtype)?, sin.clone().to_dtype(dtype)?], + D::Minus1, + )? + } else { + sin.clone().to_dtype(dtype)? + }, + cache: Tensor::cat(&[cos.clone(), sin.clone()], D::Minus1)? + .contiguous()? + .to_dtype(dtype)?, + is_gpt_neox, + }) + } + + #[cfg(feature = "cuda")] + fn execute_dtype( + &self, + dev: &CudaDevice, + q_storage: &CudaStorage, + k_storage: &CudaStorage, + q: &Tensor, + k: &Tensor, + cache_storage: &CudaStorage, + pos_storage: &CudaStorage, + ) -> Result<()> { + use mcandle_core::cuda_backend::WrapErr; + + let num_tokens = q.dim(0)?; + let rot_dim = self.cache.dim(1)?; + let num_heads = q.dim(1)?; + let num_kv_heads = k.dim(1)?; + let q_stride = q.stride()[0]; + let k_stride = k.stride()[0]; + + let func = dev.get_or_load_func( + &if self.is_gpt_neox { + kernel_name::("rotary_embedding_kernel_neox") + } else { + kernel_name::("rotary_embedding_kernel") + }, + kernels::FUSED_ROPE, + )?; + + let cfg = LaunchConfig { + grid_dim: (num_tokens as u32, 1, 1), + block_dim: (512.min((num_heads * rot_dim / 2) as u32), 1, 1), + shared_mem_bytes: 0, + }; + + let params = ( + pos_storage.as_cuda_slice::()?, + q_storage.as_cuda_slice::()?, + k_storage.as_cuda_slice::()?, + cache_storage.as_cuda_slice::()?, + rot_dim as i32, + q_stride as i64, + k_stride as i64, + num_heads as i32, + num_kv_heads as i32, + self.head_size as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + + Ok(()) + } + + #[cfg(feature = "cuda")] + fn fused_rope( + &self, + dev: &CudaDevice, + positions: &Tensor, + q: &Tensor, + k: &Tensor, + ) -> Result<()> { + let cache_type = self.cache.dtype(); + match ( + &*q.storage_and_layout().0, + &*k.storage_and_layout().0, + &*self.cache.storage_and_layout().0, + &*positions.storage_and_layout().0, + ) { + ( + Storage::Cuda(q_storage), + Storage::Cuda(k_storage), + Storage::Cuda(cache_storage), + Storage::Cuda(pos_storage), + ) => { + return match (q.dtype(), k.dtype(), cache_type) { + (DType::BF16, DType::BF16, DType::BF16) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F16, DType::F16, DType::F16) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F32, DType::F32, DType::F32) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + (DType::F64, DType::F64, DType::F64) => self.execute_dtype::( + &dev, + q_storage, + k_storage, + q, + k, + cache_storage, + pos_storage, + ), + _ => 
mcandle_core::bail!( + "DType mismatch in fused RotaryEmbedding q={:?}, k={:?}, cache={:?}", + q.dtype(), + k.dtype(), + cache_type + ), + } + } + _ => unreachable!(), + }; + } + + /// This may modify the tensors in place! + #[allow(unused_variables)] + pub fn forward( + &self, + positions: &[usize], + positions_kernel: &Tensor, + q: &mut Tensor, + k: &mut Tensor, + b_sz: usize, + ) -> Result<()> { + match (q.device(), k.device()) { + #[cfg(feature = "cuda")] + (Device::Cuda(dev), Device::Cuda(_)) => { + self.fused_rope(dev, positions_kernel, &*q, &*k)?; + } + + _ => { + *q = self.apply_rotary_emb(&*q, positions, b_sz)?; + *k = self.apply_rotary_emb(&*k, positions, b_sz)?; + } + }; + Ok(()) + } + + fn apply_rotary_emb( + &self, + x: &Tensor, + seqlen_offsets: &[usize], + b_sz: usize, + ) -> Result { + let (b_sz_seq_len, h, n_embd) = x.dims3()?; + let x = x + .reshape((b_sz, b_sz_seq_len / b_sz, h, n_embd))? + .transpose(1, 2)?; + + fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) + } + let (b_sz, n_head, seq_len, _n_embd) = x.dims4()?; + if self.is_gpt_neox { + let mut embeds = Vec::new(); + for (b, seqlen_offset) in zip(0..b_sz, seqlen_offsets) { + let cos = self.cos.narrow(0, *seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, *seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let x_b = x.i(b)?.unsqueeze(0)?; + let embed = (x_b.broadcast_mul(&cos)? + rotate_half(&x_b)?.broadcast_mul(&sin)?)?; + embeds.push(embed); + } + Tensor::cat(&embeds, 0) + } else { + let mut ropes = Vec::new(); + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + for (b, seqlen_offset) in zip(0..b_sz, seqlen_offsets) { + let cos = self.cos.narrow(0, *seqlen_offset, seq_len)?.reshape(( + seq_len, + n_embd / 2, + 1, + ))?; + let sin = self.sin.narrow(0, *seqlen_offset, seq_len)?.reshape(( + seq_len, + n_embd / 2, + 1, + ))?; + let cos = cos.broadcast_as((1, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((1, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. + // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x_b = x.i(b)?.unsqueeze(0)?; + let x_b = x_b.reshape((1, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x_b.narrow(D::Minus1, 0, 1)?; + let x1 = x_b.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + ropes.push(rope); + } + Tensor::cat(&ropes, 0) + } + } +} diff --git a/mcandle/mcandle-nn/src/rotary_emb.rs b/mcandle/mcandle-nn/src/rotary_emb.rs new file mode 100644 index 000000000..f2bb83930 --- /dev/null +++ b/mcandle/mcandle-nn/src/rotary_emb.rs @@ -0,0 +1,732 @@ +//! Rotary Embeddings +//! +use mcandle_core::{CpuStorage, Layout, Result, Shape, Tensor, D}; +use rayon::prelude::*; + +/// Interleaved variant of rotary embeddings. +/// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. 
+/// The resulting y0 and y1 are also interleaved with: +/// y0 = x0*cos - x1*sin +/// y1 = x0*sin + x1*cos +#[derive(Debug, Clone)] +struct RotaryEmbI; + +impl mcandle_core::CustomOp3 for RotaryEmbI { + fn name(&self) -> &'static str { + "rotary-emb-int" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => mcandle_core::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => mcandle_core::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => mcandle_core::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * d) + .zip(dst.par_chunks_mut(t * d)) + .for_each(|(src, dst)| { + for i_over_2 in 0..t * d / 2 { + let i = 2 * i_over_2; + dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2]; + dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2]; + } + }); + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, h, t, d).into())) + } + + use mcandle_core::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => mcandle_core::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &mcandle_core::CudaStorage, + l1: &Layout, + s2: &mcandle_core::CudaStorage, + l2: &Layout, + s3: &mcandle_core::CudaStorage, + l3: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use mcandle_core::cuda_backend::{kernel_name, kernels, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match l_src.contiguous_offsets() { + None => mcandle_core::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => mcandle_core::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match l_sin.contiguous_offsets() { + None => mcandle_core::bail!("sin input has to be contiguous"), + Some((o1, o2)) => sin.slice(o1..o2), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope_i"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32); + // SAFETY: ffi. 
+ unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use mcandle_core::backend::BackendStorage; + use mcandle_core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => mcandle_core::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = mcandle_core::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &mcandle_core::MetalStorage, + l_src: &Layout, + cos: &mcandle_core::MetalStorage, + l_cos: &Layout, + sin: &mcandle_core::MetalStorage, + l_sin: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + mcandle_core::bail!( + "dtype mismatch in rope-i {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + mcandle_core::DType::F32 => "rope_i_f32", + mcandle_core::DType::F16 => "rope_i_f16", + mcandle_core::DType::BF16 => "rope_i_bf16", + dtype => mcandle_core::bail!("rope-i is not implemented for {dtype:?}"), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-i")?; + mcandle_metal_kernels::call_rope_i( + device.metal_device(), + &command_buffer, + kernels, + name, + b * h, + t * d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let out = mcandle_core::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = cos.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + mcandle_core::bail!( + "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + mcandle_core::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + mcandle_core::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + mcandle_core::bail!("sin has to be contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI) +} + +pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let cos = cos + .narrow(0, 0, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let sin = sin + .narrow(0, 0, seq_len)? 
+ .reshape((seq_len, n_embd / 2, 1))?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) +} + +/// Contiguous variant of rope embeddings. +#[derive(Debug, Clone)] +struct RotaryEmb; + +impl mcandle_core::CustomOp3 for RotaryEmb { + fn name(&self) -> &'static str { + "rotary-emb" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => mcandle_core::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => mcandle_core::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => mcandle_core::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * d) + .zip(dst.par_chunks_mut(t * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i1 = i_t * d + i_d; + let i2 = i1 + d / 2; + let i_cs = i_t * (d / 2) + i_d; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + }); + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, h, t, d).into())) + } + + use mcandle_core::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => mcandle_core::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &mcandle_core::CudaStorage, + l1: &Layout, + s2: &mcandle_core::CudaStorage, + l2: &Layout, + s3: &mcandle_core::CudaStorage, + l3: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use mcandle_core::cuda_backend::{kernel_name, kernels, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match l_src.contiguous_offsets() { + None => mcandle_core::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => mcandle_core::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match l_sin.contiguous_offsets() 
{ + None => mcandle_core::bail!("sin input has to be contiguous"), + Some((o1, o2)) => sin.slice(o1..o2), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, + &cos, + &sin, + &dst, + (b * h) as u32, + (t * d) as u32, + d as u32, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use mcandle_core::backend::BackendStorage; + use mcandle_core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => mcandle_core::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = mcandle_core::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &mcandle_core::MetalStorage, + l_src: &Layout, + cos: &mcandle_core::MetalStorage, + l_cos: &Layout, + sin: &mcandle_core::MetalStorage, + l_sin: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + mcandle_core::bail!( + "dtype mismatch in rope {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + mcandle_core::DType::F32 => "rope_f32", + mcandle_core::DType::F16 => "rope_f16", + mcandle_core::DType::BF16 => "rope_bf16", + dtype => mcandle_core::bail!("rope is not implemented for {dtype:?}"), + }; + let (b, h, t, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-i")?; + mcandle_metal_kernels::call_rope( + device.metal_device(), + &command_buffer, + kernels, + name, + b * h, + t * d, + d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let out = mcandle_core::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = sin.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + mcandle_core::bail!( + "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + mcandle_core::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + mcandle_core::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + mcandle_core::bail!("sin has to be 
contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmb) +} + +fn rotate_half(xs: &Tensor) -> Result { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?; + let cos = Tensor::cat(&[cos, cos], D::Minus1)?; + let sin = Tensor::cat(&[sin, sin], D::Minus1)?; + let cos = cos.narrow(0, 0, seq_len)?; + let sin = sin.narrow(0, 0, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; + x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)? +} + +/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings. +#[derive(Debug, Clone)] +struct RotaryEmbThd; + +impl mcandle_core::CustomOp3 for RotaryEmbThd { + fn name(&self) -> &'static str { + "rotary-emb" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + fn inner( + src: &[T], + l_src: &Layout, + cos: &[T], + l_cos: &Layout, + sin: &[T], + l_sin: &Layout, + ) -> Result<(CpuStorage, Shape)> { + let src = match l_src.contiguous_offsets() { + None => mcandle_core::bail!("input src has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let cos = match l_cos.contiguous_offsets() { + None => mcandle_core::bail!("input cos has to be contiguous"), + Some((o1, o2)) => &cos[o1..o2], + }; + let sin = match l_sin.contiguous_offsets() { + None => mcandle_core::bail!("input sin has to be contiguous"), + Some((o1, o2)) => &sin[o1..o2], + }; + let (b, t, h, d) = l_src.shape().dims4()?; + let el_count = b * h * t * d; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(t * h * d) + .zip(dst.par_chunks_mut(t * h * d)) + .for_each(|(src, dst)| { + for i_t in 0..t { + for i_d in 0..d / 2 { + let i_cs = i_t * (d / 2) + i_d; + for i_h in 0..h { + let i1 = i_t * h * d + i_h * d + i_d; + let i2 = i1 + d / 2; + dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs]; + dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs]; + } + } + } + }); + let storage = mcandle_core::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (b, t, h, d).into())) + } + + use mcandle_core::backend::BackendStorage; + use CpuStorage::{BF16, F16, F32, F64}; + match (s1, s2, s3) { + (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3), + (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3), + _ => mcandle_core::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &mcandle_core::CudaStorage, + l1: &Layout, + s2: &mcandle_core::CudaStorage, + l2: &Layout, + s3: &mcandle_core::CudaStorage, + l3: &Layout, + ) -> Result<(mcandle_core::CudaStorage, Shape)> { + use mcandle_core::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use mcandle_core::cuda_backend::{kernel_name, kernels, WrapErr}; + use mcandle_core::{CudaDevice, WithDType}; + + fn inner( + src: &CudaSlice, + l_src: &Layout, + cos: &CudaSlice, + l_cos: &Layout, + sin: &CudaSlice, + l_sin: &Layout, + dev: &CudaDevice, + ) -> Result> { + let src = match 
l_src.contiguous_offsets() { + None => mcandle_core::bail!("src input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let cos = match l_cos.contiguous_offsets() { + None => mcandle_core::bail!("cos input has to be contiguous"), + Some((o1, o2)) => cos.slice(o1..o2), + }; + let sin = match l_sin.contiguous_offsets() { + None => mcandle_core::bail!("sin input has to be contiguous"), + Some((o1, o2)) => sin.slice(o1..o2), + }; + let (b, t, h, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let cfg = LaunchConfig::for_num_elems((el / 2) as u32); + let func = dev.get_or_load_func(&kernel_name::("rope_thd"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::(el) }.w()?; + let params = ( + &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + + use mcandle_core::backend::BackendStorage; + use mcandle_core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64}; + let dev = s1.device(); + let slice = match (&s1.slice, &s2.slice, &s3.slice) { + (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?), + (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?), + _ => mcandle_core::bail!( + "unsupported dtype for rope {:?} {:?} {:?}", + s1.dtype(), + s2.dtype(), + s3.dtype() + ), + }; + let dst = mcandle_core::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + src: &mcandle_core::MetalStorage, + l_src: &Layout, + cos: &mcandle_core::MetalStorage, + l_cos: &Layout, + sin: &mcandle_core::MetalStorage, + l_sin: &Layout, + ) -> Result<(mcandle_core::MetalStorage, Shape)> { + use mcandle_core::backend::BackendStorage; + let device = src.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() { + mcandle_core::bail!( + "dtype mismatch in rope {:?} {:?} {:?}", + src.dtype(), + cos.dtype(), + sin.dtype() + ) + } + let name = match src.dtype() { + mcandle_core::DType::F32 => "rope_thd_f32", + mcandle_core::DType::F16 => "rope_thd_f16", + mcandle_core::DType::BF16 => "rope_thd_bf16", + dtype => mcandle_core::bail!("rope_thd is not implemented for {dtype:?}"), + }; + let (b, t, h, d) = l_src.shape().dims4()?; + let el = b * h * t * d; + let output = device.new_buffer(el, src.dtype(), "rope-thd")?; + mcandle_metal_kernels::call_rope_thd( + device.metal_device(), + &command_buffer, + kernels, + name, + b, + t, + h, + d, + src.buffer(), + l_src.start_offset() * src.dtype().size_in_bytes(), + cos.buffer(), + l_cos.start_offset() * cos.dtype().size_in_bytes(), + sin.buffer(), + l_sin.start_offset() * sin.dtype().size_in_bytes(), + &output, + ) + .map_err(mcandle_core::Error::wrap)?; + let out = mcandle_core::MetalStorage::new(output, device.clone(), el, src.dtype()); + Ok((out, l_src.shape().clone())) + } +} + +pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { + let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?; + let (cos_seq_len, cos_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = sin.dims2()?; + if cos_n_embd * 2 != n_embd + || sin_n_embd * 2 != n_embd + || seq_len > cos_seq_len + || seq_len > sin_seq_len + { + mcandle_core::bail!( 
+ "inconsistent last dim size in rope {:?} {:?} {:?}", + xs.shape(), + cos.shape(), + sin.shape() + ) + } + if !xs.is_contiguous() { + mcandle_core::bail!("xs has to be contiguous in rope") + } + if !cos.is_contiguous() { + mcandle_core::bail!("cos has to be contiguous in rope") + } + if !sin.is_contiguous() { + mcandle_core::bail!("sin has to be contiguous in rope") + } + xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd) +} diff --git a/mcandle/mcandle-nn/src/sequential.rs b/mcandle/mcandle-nn/src/sequential.rs new file mode 100644 index 000000000..6015cc52f --- /dev/null +++ b/mcandle/mcandle-nn/src/sequential.rs @@ -0,0 +1,64 @@ +//! Sequential Layer +//! +//! A sequential layer used to chain multiple layers and closures. +use mcandle_core::{Module, Result, Tensor}; + +/// A sequential layer combining multiple other layers. +pub struct Sequential { + layers: Vec>, +} + +/// Creates a new empty sequential layer. +pub fn seq() -> Sequential { + Sequential { layers: vec![] } +} + +impl Sequential { + /// The number of sub-layers embedded in this layer. + pub fn len(&self) -> i64 { + self.layers.len() as i64 + } + + /// Returns true if this layer does not have any sub-layer. + pub fn is_empty(&self) -> bool { + self.layers.is_empty() + } +} + +impl Module for Sequential { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + Ok(xs) + } +} + +impl Sequential { + /// Appends a layer after all the current layers. + #[allow(clippy::should_implement_trait)] + pub fn add(mut self, layer: M) -> Self { + self.layers.push(Box::new(layer)); + self + } + + /// Appends a closure after all the current layers. + pub fn add_fn(self, f: F) -> Self + where + F: 'static + Fn(&Tensor) -> Result + Send + Sync, + { + self.add(super::func(f)) + } + + /// Applies the forward pass and returns the output for each layer. + pub fn forward_all(&self, xs: &Tensor) -> Result> { + let mut vec = Vec::with_capacity(self.layers.len()); + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)?; + vec.push(xs.clone()) + } + Ok(vec) + } +} diff --git a/mcandle/mcandle-nn/src/var_builder.rs b/mcandle/mcandle-nn/src/var_builder.rs new file mode 100644 index 000000000..2898498e2 --- /dev/null +++ b/mcandle/mcandle-nn/src/var_builder.rs @@ -0,0 +1,791 @@ +//! A `VarBuilder` for variable retrieval from models +//! +//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come +//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized +//! for training, e.g. using `VarBuilder::from_varmap`. +use crate::VarMap; +use mcandle_core::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; +use safetensors::{slice::IndexOp, tensor::SafeTensors}; +use std::collections::HashMap; +use std::sync::Arc; + +/// A structure used to retrieve variables, these variables can either come from storage or be +/// generated via some form of initialization. +/// +/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`. 
+pub struct VarBuilderArgs<'a, B: Backend> { + data: Arc>, + path: Vec, + pub dtype: DType, + _phantom: std::marker::PhantomData<&'a B>, +} + +impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { + fn clone(&self) -> Self { + Self { + data: self.data.clone(), + path: self.path.clone(), + dtype: self.dtype, + _phantom: self._phantom, + } + } +} + +/// A simple `VarBuilder`, this is less generic than `VarBuilderArgs` but should cover most common +/// use cases. +pub type VarBuilder<'a> = VarBuilderArgs<'a, Box>; + +struct TensorData { + backend: Arc, + pub dtype: DType, + pub device: Device, +} + +/// A trait that defines how tensor data is retrieved. +/// +/// Typically this would use disk storage in some specific format, or random initialization. +/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most +/// of the time. The main restriction is that it doesn't allow for specific args (besides +/// initialization hints). +pub trait Backend: Send + Sync { + type Hints: Default; + + /// Retrieve a tensor with some target shape. + fn get( + &self, + s: Shape, + name: &str, + h: Self::Hints, + dtype: DType, + dev: &Device, + ) -> Result; + + fn contains_tensor(&self, name: &str) -> bool; +} + +pub trait SimpleBackend: Send + Sync { + /// Retrieve a tensor based on a target name and shape. + fn get( + &self, + s: Shape, + name: &str, + h: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result; + + fn contains_tensor(&self, name: &str) -> bool; +} + +impl<'a> Backend for Box { + type Hints = crate::Init; + fn get( + &self, + s: Shape, + name: &str, + h: Self::Hints, + dtype: DType, + dev: &Device, + ) -> Result { + self.as_ref().get(s, name, h, dtype, dev) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.as_ref().contains_tensor(name) + } +} + +impl<'a, B: Backend> VarBuilderArgs<'a, B> { + pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { + let data = TensorData { + backend: Arc::new(backend), + dtype, + device: dev.clone(), + }; + Self { + data: Arc::new(data), + path: vec![], + dtype, + _phantom: std::marker::PhantomData, + } + } + + /// Returns the prefix of the `VarBuilder`. + pub fn prefix(&self) -> String { + self.path.join(".") + } + + /// Returns a new `VarBuilder` using the root path. + pub fn root(&self) -> Self { + Self { + data: self.data.clone(), + path: vec![], + dtype: self.dtype, + _phantom: std::marker::PhantomData, + } + } + + /// Returns a new `VarBuilder` with the prefix set to `prefix`. + pub fn set_prefix(&self, prefix: impl ToString) -> Self { + Self { + data: self.data.clone(), + path: vec![prefix.to_string()], + dtype: self.dtype, + _phantom: std::marker::PhantomData, + } + } + + /// Return a new `VarBuilder` adding `s` to the current prefix. This can be think of as `cd` + /// into a directory. + pub fn push_prefix(&self, s: S) -> Self { + let mut path = self.path.clone(); + path.push(s.to_string()); + Self { + data: self.data.clone(), + path, + dtype: self.dtype, + _phantom: std::marker::PhantomData, + } + } + + /// Short alias for `push_prefix`. + pub fn pp(&self, s: S) -> Self { + self.push_prefix(s) + } + + /// The device used by default. + pub fn device(&self) -> &Device { + &self.data.device + } + + /// The dtype used by default. 
+ pub fn dtype(&self) -> DType { + self.dtype + } + + /// Clone the VarBuilder tweaking its dtype + pub fn to_dtype(&self, dtype: DType) -> Self { + Self { + data: self.data.clone(), + path: self.path.clone(), + dtype, + _phantom: std::marker::PhantomData, + } + } + + fn path(&self, tensor_name: &str) -> String { + if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + } + } + + /// This returns true only if a tensor with the passed in name is available. E.g. when passed + /// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b` + /// exists. + pub fn contains_tensor(&self, tensor_name: &str) -> bool { + let path = self.path(tensor_name); + self.data.backend.contains_tensor(&path) + } + + /// Retrieve the tensor associated with the given name at the current path. + pub fn get_with_hints>( + &self, + s: S, + name: &str, + hints: B::Hints, + ) -> Result { + self.get_with_hints_dtype(s, name, hints, self.dtype) + } + + /// Retrieve the tensor associated with the given name at the current path. + pub fn get>(&self, s: S, name: &str) -> Result { + self.get_with_hints(s, name, Default::default()) + } + + /// Retrieve the tensor associated with the given name & dtype at the current path. + pub fn get_with_hints_dtype>( + &self, + s: S, + name: &str, + hints: B::Hints, + dtype: DType, + ) -> Result { + let path = self.path(name); + self.data + .backend + .get(s.into(), &path, hints, dtype, &self.data.device) + } + + /// Set the device of the VarBuilder. + pub fn set_device(self, device: Device) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype: self.data.dtype, + device, + }), + ..self + } + } + + /// Set the dtype of the VarBuilder. + pub fn set_dtype(self, dtype: DType) -> Self { + Self { + data: Arc::new(TensorData { + backend: self.data.backend.clone(), + dtype, + device: self.data.device.clone(), + }), + dtype, + ..self + } + } +} + +struct Zeros; + +impl SimpleBackend for Zeros { + fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result { + Tensor::zeros(s, dtype, dev) + } + + fn contains_tensor(&self, _name: &str) -> bool { + true + } +} + +impl SimpleBackend for HashMap { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = self + .get(name) + .ok_or_else(|| { + Error::CannotFindTensor { + path: name.to_string(), + } + .bt() + })? + .clone(); + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? 
+ } + tensor.to_device(dev)?.to_dtype(dtype) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.contains_key(name) + } +} + +impl SimpleBackend for VarMap { + fn get( + &self, + s: Shape, + name: &str, + h: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + VarMap::get(self, s, name, h, dtype, dev) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.data().lock().unwrap().contains_key(name) + } +} + +#[allow(dead_code)] +pub struct SafeTensorWithRouting<'a> { + routing: HashMap, + safetensors: Vec>, +} + +impl<'a> SimpleBackend for SafeTensorWithRouting<'a> { + fn get( + &self, + s: Shape, + path: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let index = self.routing.get(path).ok_or_else(|| { + Error::CannotFindTensor { + path: path.to_string(), + } + .bt() + })?; + let tensor = self.safetensors[*index] + .tensor(path)? + .load(dev)? + .to_dtype(dtype)?; + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {path}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.routing.contains_key(name) + } +} + +impl SimpleBackend for mcandle_core::npy::NpzTensors { + fn get( + &self, + s: Shape, + path: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = match self.get(path)? { + None => Err(Error::CannotFindTensor { + path: path.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {path}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).map_or(false, |v| v.is_some()) + } +} + +impl SimpleBackend for mcandle_core::pickle::PthTensors { + fn get( + &self, + s: Shape, + path: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = match self.get(path)? { + None => Err(Error::CannotFindTensor { + path: path.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {path}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).map_or(false, |v| v.is_some()) + } +} + +impl SimpleBackend for mcandle_core::safetensors::MmapedSafetensors { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = self.load(name, dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).is_ok() + } +} + +impl SimpleBackend for mcandle_core::safetensors::BufferedSafetensors { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = self.load(name, dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? 
+ } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).is_ok() + } +} + +impl<'a> SimpleBackend for mcandle_core::safetensors::SliceSafetensors<'a> { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = self.load(name, dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(mcandle_core::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).is_ok() + } +} + +impl<'a> VarBuilder<'a> { + /// Initializes a `VarBuilder` using a custom backend. + /// + /// It is preferred to use one of the more specific constructors. This + /// constructor is provided to allow downstream users to define their own + /// backends. + pub fn from_backend( + backend: Box, + dtype: DType, + device: Device, + ) -> Self { + let data = TensorData { + backend: Arc::new(backend), + dtype, + device, + }; + Self { + data: Arc::new(data), + path: vec![], + dtype, + _phantom: std::marker::PhantomData, + } + } + + /// Initializes a `VarBuilder` that uses zeros for any tensor. + pub fn zeros(dtype: DType, dev: &Device) -> Self { + Self::from_backend(Box::new(Zeros), dtype, dev.clone()) + } + + /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is + /// returned if no tensor is available under the requested path or on shape mismatches. + pub fn from_tensors(ts: HashMap, dtype: DType, dev: &Device) -> Self { + Self::from_backend(Box::new(ts), dtype, dev.clone()) + } + + /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and + /// initialized on new paths, the same tensor is used if the same path is requested multiple + /// times. This is commonly used when initializing a model before training. + /// + /// Note that it is possible to load the tensor values after model creation using the `load` + /// method on `varmap`, this can be used to start model training from an existing checkpoint. + pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self { + Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone()) + } + + /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors + /// files. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn from_mmaped_safetensors>( + paths: &[P], + dtype: DType, + dev: &Device, + ) -> Result { + let tensors = mcandle_core::safetensors::MmapedSafetensors::multi(paths)?; + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) + } + + /// Initializes a `VarBuilder` from a binary buffer in the safetensor format. + pub fn from_buffered_safetensors(data: Vec, dtype: DType, dev: &Device) -> Result { + let tensors = mcandle_core::safetensors::BufferedSafetensors::new(data)?; + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) + } + + /// Initializes a `VarBuilder` from a binary slice in the safetensor format. + pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result { + let tensors = mcandle_core::safetensors::SliceSafetensors::new(data)?; + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) + } + + /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. 
+ pub fn from_npz>(p: P, dtype: DType, dev: &Device) -> Result { + let npz = mcandle_core::npy::NpzTensors::new(p)?; + Ok(Self::from_backend(Box::new(npz), dtype, dev.clone())) + } + + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + pub fn from_pth>(p: P, dtype: DType, dev: &Device) -> Result { + let pth = mcandle_core::pickle::PthTensors::new(p, None)?; + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) + } + + /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before + /// passing the new names to the inner VarBuilder. + /// + /// ```rust + /// use mcandle_core::{Tensor, DType, Device}; + /// + /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; + /// let tensors: std::collections::HashMap<_, _> = [ + /// ("foo".to_string(), a), + /// ] + /// .into_iter() + /// .collect(); + /// let vb = mcandle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu); + /// assert!(vb.contains_tensor("foo")); + /// assert!(vb.get((2, 3), "foo").is_ok()); + /// assert!(!vb.contains_tensor("bar")); + /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() }); + /// assert!(vb.contains_tensor("bar")); + /// assert!(vb.contains_tensor("foo")); + /// assert!(vb.get((2, 3), "bar").is_ok()); + /// assert!(vb.get((2, 3), "foo").is_ok()); + /// assert!(!vb.contains_tensor("baz")); + /// # Ok::<(), mcandle_core::Error>(()) + /// ``` + pub fn rename_f String + Sync + Send + 'static>(self, f: F) -> Self { + let f: Box String + Sync + Send + 'static> = Box::new(f); + self.rename(f) + } + + pub fn rename(self, renamer: R) -> Self { + let dtype = self.dtype(); + let device = self.device().clone(); + let path = self.path.clone(); + let backend = Rename::new(self, renamer); + let backend: Box = Box::new(backend); + let data = TensorData { + backend: Arc::new(backend), + dtype, + device, + }; + Self { + data: Arc::new(data), + dtype, + path, + _phantom: std::marker::PhantomData, + } + } +} + +pub struct ShardedSafeTensors(mcandle_core::safetensors::MmapedSafetensors); + +pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>; + +impl ShardedSafeTensors { + /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors + /// files and make them usable in a sharded way. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn var_builder>( + paths: &[P], + dtype: DType, + dev: &Device, + ) -> Result> { + let tensors = mcandle_core::safetensors::MmapedSafetensors::multi(paths)?; + let backend = ShardedSafeTensors(tensors); + Ok(VarBuilderArgs::new_with_args(backend, dtype, dev)) + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct Shard { + pub dim: usize, + pub rank: usize, + pub world_size: usize, +} + +impl Default for Shard { + fn default() -> Self { + Self { + dim: 0, + rank: 0, + world_size: 1, + } + } +} + +/// Get part of a tensor, typically used to do Tensor Parallelism sharding. +/// +/// If the tensor is of size (1024, 1024). 
+///
+/// `dim` corresponds to the dimension to slice into
+/// `rank` is the rank of the current process
+/// `world_size` is the total number of ranks in the process group
+///
+/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
+/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
+/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
+impl Backend for ShardedSafeTensors {
+    type Hints = Shard;
+
+    fn get(
+        &self,
+        target_shape: Shape, // The size is only checked when the world size is 1.
+        path: &str,
+        h: Self::Hints,
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        if h.world_size == 1 {
+            // There is no sharding to be applied here so we use the default backend to speed
+            // things up.
+            return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
+        }
+
+        let Shard {
+            dim,
+            rank,
+            world_size,
+        } = h;
+        let view = self.0.get(path)?;
+        let view_dtype = view.dtype();
+        let mut shape = view.shape().to_vec();
+        let size = shape[dim];
+
+        if size % world_size != 0 {
+            return Err(Error::ShapeMismatchSplit {
+                shape: shape.into(),
+                dim,
+                n_parts: world_size,
+            });
+        }
+        let block_size = size / world_size;
+        let start = rank * block_size;
+        let stop = (rank + 1) * block_size;
+
+        // Everything is expressed in tensor dimensions; byte offsets are handled automatically
+        // by safetensors.
+
+        let iterator = if dim == 0 {
+            view.slice(start..stop).map_err(|_| {
+                Error::Msg(format!(
+                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
+                ))
+            })?
+        } else if dim == 1 {
+            view.slice((.., start..stop)).map_err(|_| {
+                Error::Msg(format!(
+                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
+                ))
+            })?
+        } else {
+            mcandle_core::bail!("Get sharded on dimensions != 0 or 1")
+        };
+
+        shape[dim] = block_size;
+
+        let view_dtype: DType = view_dtype.try_into()?;
+        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        self.0.get(name).is_ok()
+    }
+}
+
+/// This trait specifies a way to rename the queried names into names that are stored in an inner
+/// VarBuilder.
+pub trait Renamer {
+    /// This is applied to the name obtained by a name call and the resulting name is passed to the
+    /// inner VarBuilder.
+    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
+}
+
+pub struct Rename<'a, R: Renamer> {
+    inner: VarBuilder<'a>,
+    renamer: R,
+}
+
+impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
+    fn get(
+        &self,
+        s: Shape,
+        name: &str,
+        h: crate::Init,
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        let name = self.renamer.rename(name);
+        self.inner
+            .get_with_hints_dtype(s, &name, h, dtype)?
+            .to_device(dev)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        let name = self.renamer.rename(name);
+        self.inner.contains_tensor(&name)
+    }
+}
+
+impl<'a, R: Renamer> Rename<'a, R> {
+    pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
+        Self { inner, renamer }
+    }
+}
+
+impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
+    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
+        std::borrow::Cow::Owned(self(v))
+    }
+}
diff --git a/mcandle/mcandle-nn/src/var_map.rs b/mcandle/mcandle-nn/src/var_map.rs
new file mode 100644
index 000000000..a3f1a209b
--- /dev/null
+++ b/mcandle/mcandle-nn/src/var_map.rs
@@ -0,0 +1,121 @@
+//! A `VarMap` is a store that holds named variables.
+//!
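+//! A minimal usage sketch (illustrative only): variables can be created either directly through
+//! the map or through a `VarBuilder` backed by it.
+//!
+//! ```rust
+//! use mcandle_core::{DType, Device};
+//! use mcandle_nn::{Init, VarBuilder, VarMap};
+//!
+//! let varmap = VarMap::new();
+//! // The variable is created on the first request for this path and reused afterwards.
+//! let w = varmap.get((2, 3), "w", Init::Const(0.), DType::F32, &Device::Cpu)?;
+//! assert_eq!(w.dims(), &[2, 3]);
+//! let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
+//! assert!(vb.contains_tensor("w"));
+//! # Ok::<(), mcandle_core::Error>(())
+//! ```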
+use mcandle_core::{DType, Device, Result, Shape, Tensor, Var};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the store
+/// and new variables can be added by providing some initialization config in case they are
+/// missing.
+/// `VarMap` structures can be serialized in the safetensors format.
+#[derive(Clone)]
+pub struct VarMap {
+    data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+impl VarMap {
+    /// Create a new empty `VarMap`.
+    #[allow(clippy::new_without_default)]
+    pub fn new() -> Self {
+        let data = Arc::new(Mutex::new(HashMap::new()));
+        Self { data }
+    }
+
+    /// Retrieve all the variables currently stored in the map.
+    pub fn all_vars(&self) -> Vec<Var> {
+        let tensor_data = self.data.lock().unwrap();
+        #[allow(clippy::map_clone)]
+        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+    }
+
+    /// Save the map in the safetensors format.
+    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+        Ok(())
+    }
+
+    /// Load some values from a safetensors file and modify the existing variables to have these
+    /// values.
+    ///
+    /// Note that values for variables that are currently not in the map are not kept.
+    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+        let path = path.as_ref();
+        let data = unsafe { mcandle_core::safetensors::MmapedSafetensors::new(path)? };
+        let mut tensor_data = self.data.lock().unwrap();
+        for (name, var) in tensor_data.iter_mut() {
+            let data = data.load(name, var.device())?;
+            if let Err(err) = var.set(&data) {
+                mcandle_core::bail!("error setting {name} using data from {path:?}: {err}",)
+            }
+        }
+        Ok(())
+    }
+
+    /// Set a named variable to some value.
+    pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let name = name.as_ref();
+        match tensor_data.get(name) {
+            None => mcandle_core::bail!("cannot find {name} in VarMap"),
+            Some(var) => {
+                if let Err(err) = var.set(value.as_ref()) {
+                    mcandle_core::bail!("error setting {name}: {err}",)
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Set some named variables to some values.
+    ///
+    /// If an error is returned, some of the variables might have already been set to their new
+    /// values.
+    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
+        &mut self,
+        iter: I,
+    ) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        for (name, value) in iter {
+            let name = name.as_ref();
+            match tensor_data.get(name) {
+                None => mcandle_core::bail!("cannot find {name} in VarMap"),
+                Some(var) => {
+                    if let Err(err) = var.set(value.as_ref()) {
+                        mcandle_core::bail!("error setting {name}: {err}",)
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Retrieve or add a new variable.
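+    ///
+    /// If a variable already exists under `path` it is returned after a shape check, otherwise a
+    /// new variable is created using `init` and registered under that name.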
+ pub fn get>( + &self, + shape: S, + path: &str, + init: crate::Init, + dtype: DType, + device: &Device, + ) -> Result { + let shape = shape.into(); + let mut tensor_data = self.data.lock().unwrap(); + if let Some(tensor) = tensor_data.get(path) { + let tensor_shape = tensor.shape(); + if &shape != tensor_shape { + mcandle_core::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}") + } + return Ok(tensor.as_tensor().clone()); + } + let var = init.var(shape, dtype, device)?; + let tensor = var.as_tensor().clone(); + tensor_data.insert(path.to_string(), var); + Ok(tensor) + } + + pub fn data(&self) -> &Mutex> { + &self.data + } +} diff --git a/mcandle/mcandle-nn/tests/batch_norm.rs b/mcandle/mcandle-nn/tests/batch_norm.rs new file mode 100644 index 000000000..5f33e066e --- /dev/null +++ b/mcandle/mcandle-nn/tests/batch_norm.rs @@ -0,0 +1,128 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use mcandle_core::{test_utils, DType, Device, Tensor}; +use mcandle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap}; + +/* The test below has been generated using the following PyTorch code: +import torch +torch.manual_seed(19551105) +m = torch.nn.BatchNorm2d(5, affine=False) +input = torch.randn(2, 5, 3, 4) +output = m(input) +print(input.flatten()) +print(output.flatten()) +print(m.running_mean) +print(m.running_var) +*/ +#[test] +fn batch_norm_test() -> Result<()> { + let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?; + let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?; + let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?; + let input: [f32; 120] = [ + -0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975, + -0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860, + 0.5004, 0.5539, 0.9991, -0.2540, -0.0703, -0.3752, -0.1096, -0.2374, 1.0258, -2.2208, + -0.0257, 0.6073, -1.1627, -0.0964, -1.9718, 1.6577, 0.1931, -0.3692, -0.8011, 0.9059, + 0.4797, 0.6521, -0.0165, -0.6683, -0.4148, 2.0649, -0.8276, 1.7947, -0.2061, 0.5812, + -1.3598, 1.6192, 1.0466, -0.4423, 0.4202, 0.1749, 0.6969, 0.2616, -0.0369, -1.4951, + -0.0814, -0.1877, 0.0267, 0.6150, 0.2402, -1.1440, -2.0068, 0.6032, -2.6639, 0.8260, + 0.1085, -0.1693, 1.2805, 0.7654, -0.4930, 0.3770, 1.1309, 0.2303, 0.2949, -0.2634, -0.5225, + 0.4269, 0.6341, 1.5736, 0.9827, -1.2499, 0.3509, -1.6243, -0.8123, 0.7634, -0.3047, 0.0143, + -0.4032, 0.0537, 0.7022, 0.8405, -1.2221, -1.6847, -0.0714, -0.1608, 0.5579, -1.5858, + 0.4617, -0.6480, 0.1332, 0.0419, -0.9784, 0.4173, 1.2313, -1.9046, -0.1656, 0.1259, 0.0763, + 1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205, + ]; + let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?; + let output = bn.forward_train(&input)?; + assert_eq!(output.dims(), &[2, 5, 3, 4]); + let output = output.flatten_all()?; + assert_eq!( + test_utils::to_vec1_round(&output, 4)?, + &[ + -0.6391, -0.9414, 1.8965, -0.5444, 2.0007, 0.1283, 0.4287, 0.014, 0.4387, -0.4818, + 0.1085, -0.0842, -1.6809, -2.0057, -2.6714, 0.8328, 0.2262, -0.1268, 0.8943, -0.0123, + 0.3377, 0.3973, 0.8928, -0.5021, 0.0861, -0.2324, 0.0451, -0.0884, 1.2311, -2.1603, + 0.1327, 0.7939, -1.055, 0.0589, -1.9002, 1.8912, 0.2918, -0.3253, -0.7993, 1.0741, + 0.6063, 0.7955, 0.0617, -0.6536, -0.3754, 2.3461, -0.8284, 2.0495, -0.201, 0.6476, + -1.4446, 1.7665, 1.1493, -0.4556, 0.4741, 0.2097, 0.7723, 0.3031, 
-0.0186, -1.5905, + 0.053, -0.0572, 0.165, 0.7746, 0.3862, -1.0481, -1.9422, 0.7624, -2.6231, 0.9933, + 0.2498, -0.0381, 1.2061, 0.6327, -0.7681, 0.2004, 1.0396, 0.037, 0.109, -0.5125, + -0.8009, 0.2559, 0.4865, 1.5324, 1.1861, -1.1461, 0.5261, -1.5372, -0.689, 0.957, + -0.1587, 0.1745, -0.2616, 0.2156, 0.8931, 1.0375, -1.2614, -1.7691, 0.0015, -0.0966, + 0.6921, -1.6605, 0.5866, -0.6313, 0.226, 0.1258, -0.9939, 0.5378, 1.3484, -2.0319, + -0.1574, 0.1568, 0.1034, 1.5574, -0.9614, -0.0967, -0.313, -0.7047, -1.5264, 1.0134 + ] + ); + let bn2 = BatchNorm::new( + 5, + running_mean, + running_var, + Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?, + Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?, + 1e-8, + )?; + let output2 = bn2.forward_train(&input)?; + assert_eq!(output2.dims(), &[2, 5, 3, 4]); + let output2 = output2.flatten_all()?; + let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?; + let sum_diff2 = diff2.sum_keepdim(0)?; + assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]); + + assert_eq!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020] + ); + assert_eq!( + test_utils::to_vec1_round(bn.running_var(), 4)?, + &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898] + ); + Ok(()) +} + +// This test makes sure that we can train a batch norm layer using a VarMap. +#[test] +fn train_batch_norm() -> Result<()> { + let vm = VarMap::new(); + let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu); + let bn = batch_norm(1, BatchNormConfig::default(), vb)?; + // Get a copy of the original mean to ensure it is being updated. + let original_mean = bn.running_mean().detach().copy()?; + let var_map_mean = { + vm.data() + .lock() + .unwrap() + .get("running_mean") + .unwrap() + .clone() + }; + // Ensure the var map mean is the same as the running mean. + assert_eq!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?, + ); + // Train with a something guaranteed to be different from the running mean. + let mean_plus_one = { + let one = original_mean.ones_like()?; + original_mean.add(&one)?.reshape((1, 1))? + }; + + bn.forward_train(&mean_plus_one)?; + // Assert that the running mean has been updated. + assert_ne!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + test_utils::to_vec1_round(&original_mean, 4)?, + ); + + // Assert that the var map mean has been updated. + assert_eq!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?, + ); + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/group_norm.rs b/mcandle/mcandle-nn/tests/group_norm.rs new file mode 100644 index 000000000..2e6e761f6 --- /dev/null +++ b/mcandle/mcandle-nn/tests/group_norm.rs @@ -0,0 +1,105 @@ +/* Equivalent PyTorch code. 
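+   (the expected values asserted in the test below were generated by running this snippet)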
+import torch +from torch.nn.functional import group_norm +t = torch.tensor( + [[[-0.3034, 0.2726, -0.9659], + [-1.1845, -1.3236, 0.0172], + [ 1.9507, 1.2554, -0.8625], + [ 1.0682, 0.3604, 0.3985], + [-0.4957, -0.4461, -0.9721], + [ 1.5157, -0.1546, -0.5596]], + + [[-1.6698, -0.4040, -0.7927], + [ 0.3736, -0.0975, -0.1351], + [-0.9461, 0.5461, -0.6334], + [-1.0919, -0.1158, 0.1213], + [-0.9535, 0.1281, 0.4372], + [-0.2845, 0.3488, 0.5641]]]) +print(group_norm(t, num_groups=2)) +print(group_norm(t, num_groups=3)) +*/ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use mcandle_core::test_utils::to_vec3_round; +use mcandle_core::{Device, Tensor}; +use mcandle_nn::{GroupNorm, Module}; + +#[test] +fn group_norm() -> Result<()> { + let device = &Device::Cpu; + let w = Tensor::from_vec(vec![1f32; 6], 6, device)?; + let b = Tensor::from_vec(vec![0f32; 6], 6, device)?; + let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?; + let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?; + + let input = Tensor::new( + &[ + [ + [-0.3034f32, 0.2726, -0.9659], + [-1.1845, -1.3236, 0.0172], + [1.9507, 1.2554, -0.8625], + [1.0682, 0.3604, 0.3985], + [-0.4957, -0.4461, -0.9721], + [1.5157, -0.1546, -0.5596], + ], + [ + [-1.6698, -0.4040, -0.7927], + [0.3736, -0.0975, -0.1351], + [-0.9461, 0.5461, -0.6334], + [-1.0919, -0.1158, 0.1213], + [-0.9535, 0.1281, 0.4372], + [-0.2845, 0.3488, 0.5641], + ], + ], + device, + )?; + assert_eq!( + to_vec3_round(&gn2.forward(&input)?, 4)?, + &[ + [ + [-0.1653, 0.3748, -0.7866], + [-0.9916, -1.1220, 0.1353], + [1.9485, 1.2965, -0.6896], + [1.2769, 0.3628, 0.4120], + [-0.7427, -0.6786, -1.3578], + [1.8547, -0.3022, -0.8252] + ], + [ + [-1.9342, 0.0211, -0.5793], + [1.2223, 0.4945, 0.4365], + [-0.8163, 1.4887, -0.3333], + [-1.7960, -0.0392, 0.3875], + [-1.5469, 0.3998, 0.9561], + [-0.3428, 0.7970, 1.1845] + ] + ] + ); + assert_eq!( + to_vec3_round(&gn3.forward(&input)?, 4)?, + &[ + [ + [0.4560, 1.4014, -0.6313], + [-0.9901, -1.2184, 0.9822], + [1.4254, 0.6360, -1.7682], + [0.4235, -0.3800, -0.3367], + [-0.3890, -0.3268, -0.9862], + [2.1325, 0.0386, -0.4691] + ], + [ + [-1.8797, 0.0777, -0.5234], + [1.2802, 0.5517, 0.4935], + [-1.0102, 1.5327, -0.4773], + [-1.2587, 0.4047, 0.8088], + [-1.9074, 0.1691, 0.7625], + [-0.6230, 0.5928, 1.0061] + ] + ] + ); + + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/kv_cache.rs b/mcandle/mcandle-nn/tests/kv_cache.rs new file mode 100644 index 000000000..5913c99dd --- /dev/null +++ b/mcandle/mcandle-nn/tests/kv_cache.rs @@ -0,0 +1,110 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use mcandle_core::{Device, Result, Tensor}; + +#[test] +fn kv_cache() -> Result<()> { + let mut cache = mcandle_nn::kv_cache::Cache::new(0, 16); + for _ in [0, 1] { + assert_eq!(cache.current_seq_len(), 0); + let data = cache.current_data()?; + assert!(data.is_none()); + let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3.]); + let t = Tensor::new(&[4f32], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3., 4., 0., 5., 6., 7.]); + 
assert_eq!(cache.current_seq_len(), 8); + cache.reset(); + } + Ok(()) +} + +#[test] +fn rotating_kv_cache() -> Result<()> { + let mut cache = mcandle_nn::kv_cache::RotatingCache::new(0, 6); + for _ in [0, 1] { + assert_eq!(cache.offset(), 0); + assert_eq!(cache.current_seq_len(), 0); + let data = cache.current_data()?; + assert!(data.is_none()); + let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [1., 2., 3.]); + let t = Tensor::new(&[4.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [6., 7., 3., 4., 0., 5.]); + assert_eq!(cache.current_seq_len(), 8); + assert_eq!(cache.offset(), 2); + + let t = Tensor::new(&[8.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [6., 7., 8., 4., 0., 5.]); + assert_eq!(cache.current_seq_len(), 9); + assert_eq!(cache.offset(), 3); + + let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [6., 7., 8., 9., 10., 11.]); + assert_eq!(cache.current_seq_len(), 12); + assert_eq!(cache.offset(), 0); + + let t = Tensor::new(&[12.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [12., 7., 8., 9., 10., 11.]); + assert_eq!(cache.current_seq_len(), 13); + assert_eq!(cache.offset(), 1); + + let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::()?, + &[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]] + ); + let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::()?, + &[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]], + ); + let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]); + assert_eq!(cache.current_seq_len(), 22); + assert_eq!(cache.offset(), 0); + + let mask = cache.attn_mask(1, &Device::Cpu)?; + assert!(mask.is_none()); + let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::()?, + &[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]] + ); + let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::()?, + &[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]] + ); + let t = Tensor::new(&[42.], &Device::Cpu)?; + + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::()?, [42., 4., 5., 6., 7., 8.]); + assert_eq!(cache.current_seq_len(), 23); + assert_eq!(cache.offset(), 1); + + cache.reset(); + } + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/layer_norm.rs b/mcandle/mcandle-nn/tests/layer_norm.rs new file mode 100644 index 000000000..1ff9de1b3 --- /dev/null +++ b/mcandle/mcandle-nn/tests/layer_norm.rs @@ -0,0 +1,55 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use mcandle_core::{test_utils, Device, Tensor}; +use mcandle_nn::{LayerNorm, Module}; + +#[test] +fn layer_norm() -> Result<()> { + let device = &Device::Cpu; + let w = Tensor::new(&[3f32], device)?; + let b = Tensor::new(&[0.5f32], device)?; + let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8); + let ln3 = LayerNorm::new( + Tensor::cat(&[&w, &w, &w], 0)?, + Tensor::cat(&[&b, &b, &b], 0)?, + 1e-8, + ); + let ln = LayerNorm::new(w, b, 1e-8); + + let two = 
Tensor::new(&[[[2f32]]], device)?; + let res = ln.forward(&two)?.flatten_all()?; + assert_eq!(res.to_vec1::()?, [0.5f32]); + + let inp = Tensor::new(&[[[4f32, 0f32]]], device)?; + let res = ln2.forward(&inp)?; + assert_eq!(res.to_vec3::()?, [[[3.5f32, -2.5]]]); + + let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?; + let res = ln3.forward(&inp)?; + assert_eq!( + test_utils::to_vec3_round(&res, 4)?, + [[ + [-3.1742, 0.5, 4.1742], + [-3.1742, 0.5, 4.1742], + [4.1742, 0.5, -3.1742] + ]] + ); + let mean = (res.sum_keepdim(2)? / 3.0)?; + // The average value should be `b`. + assert_eq!( + test_utils::to_vec3_round(&mean, 4)?, + [[[0.5], [0.5], [0.5]]] + ); + let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?; + // The standard deviation should be sqrt(`w`). + assert_eq!( + test_utils::to_vec3_round(&std, 4)?, + [[[1.7321], [1.7321], [1.7321]]] + ); + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/loss.rs b/mcandle/mcandle-nn/tests/loss.rs new file mode 100644 index 000000000..587a4280a --- /dev/null +++ b/mcandle/mcandle-nn/tests/loss.rs @@ -0,0 +1,88 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use mcandle_core::test_utils::to_vec0_round; +use mcandle_core::{Device, Result, Tensor}; + +/* Equivalent python code: +import torch +import torch.nn.functional as F +input = torch.tensor([ + [ 1.1050, 0.3013, -1.5394, -2.1528, -0.8634], + [ 1.0730, -0.9419, -0.1670, -0.6582, 0.5061], + [ 0.8318, 1.1154, -0.3610, 0.5351, 1.0830]]) + +target = torch.tensor([1, 0, 4]) +print(F.nll_loss(F.log_softmax(input, dim=1), target)) +print(F.cross_entropy(input, target)) +*/ +#[test] +fn nll_and_cross_entropy() -> Result<()> { + let cpu = Device::Cpu; + let input = Tensor::new( + &[ + [1.1050f32, 0.3013, -1.5394, -2.1528, -0.8634], + [1.0730, -0.9419, -0.1670, -0.6582, 0.5061], + [0.8318, 1.1154, -0.3610, 0.5351, 1.0830], + ], + &cpu, + )?; + let target = Tensor::new(&[1u32, 0, 4], &cpu)?; + + let log_softmax = mcandle_nn::ops::log_softmax(&input, 1)?; + let loss = mcandle_nn::loss::nll(&log_softmax, &target)?; + assert_eq!(to_vec0_round(&loss, 4)?, 1.1312); + let loss = mcandle_nn::loss::cross_entropy(&input, &target)?; + assert_eq!(to_vec0_round(&loss, 4)?, 1.1312); + Ok(()) +} + +/* Equivalent python code: +import torch +import torch.nn.functional as F + +inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178], + [ 0.0419, 0.0763, -1.0457, -1.6692], + [-1.0494, 0.8111, 1.5723, 1.2315], + [ 1.3081, 0.6641, 1.1802, -0.2547], + [ 0.5292, 0.7636, 0.3692, -0.8318]]) + +target = torch.Tensor([[0., 1., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 0., 1.], + [1., 0., 0., 0.], + [0., 0., 1., 0.]]) + +print(F.binary_cross_entropy_with_logits(inp, target)) +*/ +#[test] +fn binary_cross_entropy_with_logit() -> Result<()> { + let cpu = Device::Cpu; + + let inp = [ + [2.3611f32, -0.8813, -0.5006, -0.2178], + [0.0419, 0.0763, -1.0457, -1.6692], + [-1.0494, 0.8111, 1.5723, 1.2315], + [1.3081, 0.6641, 1.1802, -0.2547], + [0.5292, 0.7636, 0.3692, -0.8318], + ]; + + let target = [ + [0.0f32, 1., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 0., 1.], + [1., 0., 0., 0.], + [0., 0., 1., 0.], + ]; + + let inp = Tensor::new(&inp, &cpu)?; + let target = Tensor::new(&target, &cpu)?; + + let loss = mcandle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?; + + assert_eq!(to_vec0_round(&loss, 4)?, 0.8224); + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/one_hot.rs b/mcandle/mcandle-nn/tests/one_hot.rs new 
file mode 100644 index 000000000..c3f567dda --- /dev/null +++ b/mcandle/mcandle-nn/tests/one_hot.rs @@ -0,0 +1,120 @@ +use mcandle_core::{Result, Shape, Tensor}; +use mcandle_nn::encoding::one_hot; + +#[test] +fn test_i64_one_hot() -> Result<()> { + let device = mcandle_core::Device::Cpu; + + let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?; + let depth = 4; + + let on_value = 1.0; + let off_value = 0.0; + + let one_hot = one_hot::(indices, depth, on_value, off_value)?; + + let expected_matrix = [ + [[1., 0., 0., 0.], [0., 0., 1., 0.]], + [[0., 1., 0., 0.], [0., 0., 0., 0.]], + ]; + + assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth))); + + let matrix = one_hot.to_vec3::()?; + + assert_eq!(matrix, expected_matrix); + + Ok(()) +} + +#[test] +fn test_rank_3_one_hot() -> Result<()> { + let device = mcandle_core::Device::Cpu; + + let indices = Tensor::new( + vec![ + vec![vec![0i64, 1], vec![2, 3]], + vec![vec![3, 1], vec![1, -1]], + ], + &device, + )?; + let depth = 4; + + let on_value = 1.0; + let off_value = 0.0; + + let one_hot = one_hot::(indices, depth, on_value, off_value)?; + + let expected_matrix = Tensor::new( + vec![ + vec![ + vec![vec![1f32, 0., 0., 0.], vec![0., 1., 0., 0.]], + vec![vec![0., 0., 1., 0.], vec![0., 0., 0., 1.]], + ], + vec![ + vec![vec![0., 0., 0., 1.], vec![0., 1., 0., 0.]], + vec![vec![0., 1., 0., 0.], vec![0., 0., 0., 0.]], + ], + ], + &device, + )?; + + assert_eq!(one_hot.shape(), expected_matrix.shape()); + assert_eq!(one_hot.dims(), expected_matrix.dims()); + + let matrix = one_hot.get(1)?.to_vec3::()?; + let expected_matrix = expected_matrix.get(1)?.to_vec3::()?; + + assert_eq!(matrix, expected_matrix); + + Ok(()) +} + +#[test] +fn test_u8_one_cold() -> Result<()> { + let device = mcandle_core::Device::Cpu; + let depth = 4; + let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?; + + let on_value = 0u8; + let off_value = 1; + + // Note that the method does not require the turbofish operator, as the type is inferred from the on_value. 
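+    // The explicit form `one_hot::<u8>(indices, depth, on_value, off_value)` would be equivalent.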
+ let one_cold = one_hot(indices, depth, on_value, off_value)?; + + let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 1]]]; + + assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth))); + + let matrix = one_cold.to_vec3::()?; + + assert_eq!(matrix, expected_matrix); + + Ok(()) +} + +#[test] +fn test_iter() -> Result<()> { + let device = mcandle_core::Device::Cpu; + let depth = 4; + let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device)?; + let matrix = indices.to_vec2::()?; + let (dim1, dim2) = indices.dims2()?; + + let iter = (0..dim1).flat_map(|i| (0..dim2).map(move |j| (i, j))); + + let mut v = vec![0; depth * dim1 * dim2]; + + for (i, j) in iter { + let idx = i * depth * dim2 + j * depth; + v[idx] = matrix[i][j]; + } + + for (i, row) in matrix.iter().enumerate() { + for (j, &value) in row.iter().enumerate() { + let idx = i * depth * dim2 + j * depth; + assert_eq!(v[idx], value); + } + } + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/ops.rs b/mcandle/mcandle-nn/tests/ops.rs new file mode 100644 index 000000000..aa6231131 --- /dev/null +++ b/mcandle/mcandle-nn/tests/ops.rs @@ -0,0 +1,260 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use mcandle_core::{test_device, test_utils::to_vec3_round, Device, Result, Tensor}; + +fn softmax(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let t0 = mcandle_nn::ops::softmax(&tensor.log()?, 0)?; + let t1 = mcandle_nn::ops::softmax(&tensor.log()?, 1)?; + let t2 = mcandle_nn::ops::softmax(&tensor.log()?, 2)?; + assert_eq!( + to_vec3_round(&t0, 4)?, + &[ + // 3/5, 1/2, 4/11 + [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]], + // 2/5, 1/2, 7/11 + [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]] + ] + ); + assert_eq!( + to_vec3_round(&t1, 4)?, + &[ + // 3/4, 1/6, 4/13 + [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]], + // 2/10, 1/3, 7/15 + [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]] + ] + ); + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + // (3, 1, 4) / 8, (1, 5, 9) / 15 + [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]], + // (2, 1, 7) / 10, (8, 2, 8) / 18 + [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]] + ] + ); + let t2 = mcandle_nn::ops::softmax_last_dim(&tensor.log()?)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + // (3, 1, 4) / 8, (1, 5, 9) / 15 + [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]], + // (2, 1, 7) / 10, (8, 2, 8) / 18 + [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]] + ] + ); + Ok(()) +} + +fn rms_norm(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?; + let t = mcandle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?; + assert_eq!( + to_vec3_round(&t, 4)?, + &[ + [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], + [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]] + ] + ); + let t2 = mcandle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + [[1.019, 0.6794, 4.0762], [0.1674, 1.6744, 4.521]], + [[0.4714, 0.4714, 4.9497], [1.206, 0.603, 3.6181]] + ] + ); + let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + +fn rms_norml(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, seq_len, head_dim) = (24, 70, 64); + let el_count = b_size * seq_len * 
head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?; + let alpha = Tensor::ones(head_dim, mcandle_core::DType::F32, device)?; + let t = mcandle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?; + let t2 = mcandle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?; + let diff = (t - t2)? + .abs()? + .flatten_all()? + .max(0)? + .reshape(())? + .to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + +fn layer_norm(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?; + let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?; + let t = mcandle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?; + assert_eq!( + to_vec3_round(&t, 4)?, + &[ + [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]], + [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]] + ] + ); + let t2 = mcandle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]], + [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]] + ] + ); + let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + +fn layer_norml(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, seq_len, head_dim) = (24, 70, 64); + let el_count = b_size * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?; + let alpha = Tensor::ones(head_dim, mcandle_core::DType::F32, device)?; + let beta = Tensor::zeros(head_dim, mcandle_core::DType::F32, device)?; + let t = mcandle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?; + let t2 = mcandle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?; + let diff = (t - t2)? + .abs()? + .flatten_all()? + .max(0)? + .reshape(())? 
+ .to_vec0::()?; + assert!(diff < 1e-5); + Ok(()) +} + +#[test] +fn softmax_numerical_stability() -> Result<()> { + let dev = &Device::Cpu; + let xs = Tensor::new(&[1234f32, 0.], dev)?; + let softmax = mcandle_nn::ops::softmax(&xs, 0)?; + assert_eq!(softmax.to_vec1::()?, &[1f32, 0.]); + Ok(()) +} + +fn ropei(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = mcandle_nn::rotary_emb::rope_i(&src, &cos, &sin)?; + let rope2 = mcandle_nn::rotary_emb::rope_i_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +fn rope(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = mcandle_nn::rotary_emb::rope(&src, &cos, &sin)?; + let rope2 = mcandle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +fn rope_thd(device: &Device) -> Result<()> { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16); + let el_count = b_size * num_head * seq_len * head_dim; + let mut rng = StdRng::seed_from_u64(299792458); + let src: Vec = (0..el_count).map(|_| rng.gen::()).collect(); + let cos: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let sin: Vec = (0..seq_len * head_dim / 2) + .map(|_| rng.gen::()) + .collect(); + let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?; + let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?; + let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?; + let rope1 = { + let src = src.transpose(1, 2)?.contiguous()?; + mcandle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)? 
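+        // `rope_thd` operates on a (batch, time, head, dim) layout, hence the transposes around
+        // the call before comparing against `rope_slow` on the original (b, h, t, d) layout.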
+ }; + let rope2 = mcandle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?; + let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::()?; + if device.is_cpu() { + assert_eq!(sum_diff, 0.); + } else { + assert!(sum_diff < 1e-4); + } + Ok(()) +} + +fn sigmoid(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let s1 = mcandle_nn::ops::sigmoid(&tensor)?; + let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?; + let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::()?; + assert_eq!(diff, 0.); + Ok(()) +} + +test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal); +test_device!(rope, rope_cpu, rope_gpu, rope_metal); +test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal); +test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); +test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); +test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal); +test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal); +test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal); +test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal); diff --git a/mcandle/mcandle-nn/tests/optim.rs b/mcandle/mcandle-nn/tests/optim.rs new file mode 100644 index 000000000..23f883cd7 --- /dev/null +++ b/mcandle/mcandle-nn/tests/optim.rs @@ -0,0 +1,160 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use mcandle_core::test_utils::{to_vec0_round, to_vec2_round}; + +use anyhow::Result; +use mcandle_core::{DType, Device, Tensor, Var}; +use mcandle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD}; + +#[test] +fn sgd_optim() -> Result<()> { + let x = Var::new(0f32, &Device::Cpu)?; + let mut sgd = SGD::new(vec![x.clone()], 0.1)?; + let xt = x.as_tensor(); + for _step in 0..100 { + let loss = ((xt - 4.2)? * (xt - 4.2)?)?; + sgd.backward_step(&loss)? + } + assert_eq!(x.to_scalar::()?, 4.199999); + Ok(()) +} + +/* The results of this test have been checked against the following PyTorch code. + import torch + from torch import optim + + w_gen = torch.tensor([[3., 1.]]) + b_gen = torch.tensor([-2.]) + + sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) + sample_ys = sample_xs.matmul(w_gen.t()) + b_gen + + m = torch.nn.Linear(2, 1) + with torch.no_grad(): + m.weight.zero_() + m.bias.zero_() + optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.) + for _step in range(1000): + optimizer.zero_grad() + ys = m(sample_xs) + loss = ((ys - sample_ys)**2).sum() + loss.backward() + optimizer.step() + print(m.weight) + print(m.bias) +*/ +#[test] +fn sgd_linear_regression() -> Result<()> { + // Generate some linear data, y = 3.x1 + x2 - 2. + let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + // Now use backprop to run a linear regression between samples and get the coefficients back. 
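+    // With lr = 0.004 and 1000 steps, the fit should land close to the generating values
+    // w = [3., 1.] and b = -2. (see the assertions below).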
+ let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; + let b = Var::new(0f32, &Device::Cpu)?; + let mut sgd = SGD::new(vec![w.clone(), b.clone()], 0.004)?; + let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); + for _step in 0..1000 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + sgd.backward_step(&loss)?; + } + assert_eq!(w.to_vec2::()?, &[[2.9983196, 0.99790204]]); + assert_eq!(b.to_scalar::()?, -1.9796902); + Ok(()) +} + +/* The following test returns the same values as the PyTorch code below. +import torch +from torch import optim + +w_gen = torch.tensor([[3., 1.]]) +b_gen = torch.tensor([-2.]) + +sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) +sample_ys = sample_xs.matmul(w_gen.t()) + b_gen + +m = torch.nn.Linear(2, 1) +with torch.no_grad(): + m.weight.zero_() + m.bias.zero_() +optimizer = optim.AdamW(m.parameters(), lr=0.1) +for _step in range(100): + optimizer.zero_grad() + ys = m(sample_xs) + loss = ((ys - sample_ys)**2).sum() + loss.backward() + optimizer.step() +print(m.weight) +print(m.bias) +*/ +#[test] +fn adamw_linear_regression() -> Result<()> { + let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + // Now use backprop to run a linear regression between samples and get the coefficients back. + let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; + let b = Var::new(0f32, &Device::Cpu)?; + let params = ParamsAdamW { + lr: 0.1, + ..Default::default() + }; + let mut opt = AdamW::new(vec![w.clone(), b.clone()], params)?; + let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); + for _step in 0..100 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + opt.backward_step(&loss)?; + } + assert_eq!(to_vec2_round(w.as_tensor(), 4)?, &[[2.7257, 0.7097]]); + assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873); + Ok(()) +} + +#[test] +fn adamw_linear_regression_varmap() -> Result<()> { + use mcandle_nn::Init::Const; + + // Similar as the previous test but using a VarMap. 
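+    // The `VarMap` owns the variables here, so the optimizer is built from `var_map.all_vars()`
+    // rather than from explicitly constructed `Var`s.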
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + let mut var_map = mcandle_nn::VarMap::new(); + + let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?; + let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?; + let params = ParamsAdamW { + lr: 0.1, + ..Default::default() + }; + let mut opt = AdamW::new(var_map.all_vars(), params)?; + let lin = Linear::new(w, Some(b)); + for _step in 0..100 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + opt.backward_step(&loss)?; + } + assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]); + assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873); + + var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?; + var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?; + + assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]); + assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.); + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/rnn.rs b/mcandle/mcandle-nn/tests/rnn.rs new file mode 100644 index 000000000..6be04789b --- /dev/null +++ b/mcandle/mcandle-nn/tests/rnn.rs @@ -0,0 +1,101 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use mcandle_core::{test_utils::to_vec2_round, DType, Device, Result, Tensor}; +use mcandle_nn::RNN; + +/* The following test can be verified against PyTorch using the following snippet. +import torch +from torch import nn +lstm = nn.LSTM(2, 3, 1) +lstm.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 24.).reshape(12, 2).cos()) +lstm.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 36.).reshape(12, 3).sin()) +lstm.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2])) +lstm.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1, 1, -0.5, 2]).cos()) +state = torch.zeros((1, 3)), torch.zeros((1, 3)) +for inp in [3., 1., 4., 1., 5., 9., 2.]: + inp = torch.tensor([[inp, inp * 0.5]]) + _out, state = lstm(inp, state) +print(state) +# (tensor([[ 0.9919, 0.1738, -0.1451]], grad_fn=...), tensor([[ 5.7250, 0.4458, -0.2908]], grad_fn=...)) +*/ +#[test] +fn lstm() -> Result<()> { + let cpu = &Device::Cpu; + let w_ih = Tensor::arange(0f32, 24f32, cpu)?.reshape((12, 2))?; + let w_ih = w_ih.cos()?; + let w_hh = Tensor::arange(0f32, 36f32, cpu)?.reshape((12, 3))?; + let w_hh = w_hh.sin()?; + let b_ih = Tensor::new( + &[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1., 1., -0.5, 2.], + cpu, + )?; + let b_hh = b_ih.cos()?; + let tensors: std::collections::HashMap<_, _> = [ + ("weight_ih_l0".to_string(), w_ih), + ("weight_hh_l0".to_string(), w_hh), + ("bias_ih_l0".to_string(), b_ih), + ("bias_hh_l0".to_string(), b_hh), + ] + .into_iter() + .collect(); + let vb = mcandle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu); + let lstm = mcandle_nn::lstm(2, 3, Default::default(), vb)?; + let mut state = lstm.zero_state(1)?; + for inp in [3f32, 1., 4., 1., 5., 9., 2.] { + let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?; + state = lstm.step(&inp, &state)? 
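+        // Each call to `step` consumes a single timestep; the final (h, c) state is checked
+        // against the PyTorch reference values below.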
+ } + let h = state.h(); + let c = state.c(); + assert_eq!(to_vec2_round(h, 4)?, &[[0.9919, 0.1738, -0.1451]]); + assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]); + Ok(()) +} + +/* The following test can be verified against PyTorch using the following snippet. +import torch +from torch import nn +gru = nn.GRU(2, 3, 1) +gru.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 18.).reshape(9, 2).cos()) +gru.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 27.).reshape(9, 3).sin()) +gru.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1])) +gru.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]).cos()) +state = torch.zeros((1, 3)) +for inp in [3., 1., 4., 1., 5., 9., 2.]: + inp = torch.tensor([[inp, inp * 0.5]]) + _out, state = gru(inp, state) +print(state) +# tensor([[ 0.0579, 0.8836, -0.9991]], grad_fn=) +*/ +#[test] +fn gru() -> Result<()> { + let cpu = &Device::Cpu; + let w_ih = Tensor::arange(0f32, 18f32, cpu)?.reshape((9, 2))?; + let w_ih = w_ih.cos()?; + let w_hh = Tensor::arange(0f32, 27f32, cpu)?.reshape((9, 3))?; + let w_hh = w_hh.sin()?; + let b_ih = Tensor::new(&[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1.], cpu)?; + let b_hh = b_ih.cos()?; + let tensors: std::collections::HashMap<_, _> = [ + ("weight_ih_l0".to_string(), w_ih), + ("weight_hh_l0".to_string(), w_hh), + ("bias_ih_l0".to_string(), b_ih), + ("bias_hh_l0".to_string(), b_hh), + ] + .into_iter() + .collect(); + let vb = mcandle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu); + let gru = mcandle_nn::gru(2, 3, Default::default(), vb)?; + let mut state = gru.zero_state(1)?; + for inp in [3f32, 1., 4., 1., 5., 9., 2.] { + let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?; + state = gru.step(&inp, &state)? + } + let h = state.h(); + assert_eq!(to_vec2_round(h, 4)?, &[[0.0579, 0.8836, -0.9991]]); + Ok(()) +} diff --git a/mcandle/mcandle-nn/tests/sdpa.rs b/mcandle/mcandle-nn/tests/sdpa.rs new file mode 100644 index 000000000..c1be98967 --- /dev/null +++ b/mcandle/mcandle-nn/tests/sdpa.rs @@ -0,0 +1,310 @@ +#[cfg(feature = "metal")] +mod metal_sdpa_tests { + #[test] + fn sdpa_full() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + + // Force seqlen = 100 + const BS: usize = 4; + const R: usize = 4; + const L: usize = 4; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = mcandle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = mcandle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? 
+ .to_scalar()?; + + assert!(error <= 0.0005, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = mcandle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = mcandle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_full_softcapping() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 4; + const L: usize = 4; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = mcandle_nn::ops::softmax_last_dim( + &att.to_dtype(DType::F32)? + .div(SOFTCAP)? + .tanh()? + .mul(SOFTCAP)?, + )? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = mcandle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0004, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_softcapping() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = mcandle_nn::ops::softmax_last_dim( + &att.to_dtype(DType::F32)? + .div(SOFTCAP)? + .tanh()? + .mul(SOFTCAP)?, + )? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = mcandle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? 
+ .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_cross() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 24; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = mcandle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = mcandle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0013, "{}", error); + + Ok(()) + } + + #[test] + fn attn_softmax_mask() -> mcandle_core::Result<()> { + use mcandle_core::{Device, Tensor}; + + let device = Device::new_metal(0)?; + + let tensor = Tensor::randn(0f32, 1f32, (4, 32, 64, 64), &device)?; + let truemask = Tensor::full(f32::MIN, (64, 64), &device)?.contiguous()?; + + let ground_truth = mcandle_nn::ops::softmax_last_dim(&tensor.broadcast_add(&truemask)?)?; + + let softmax_out = mcandle_nn::ops::attn_softmax_last_dim(&tensor, &truemask, 1.)?; + + let error: f32 = ((&ground_truth - &softmax_out)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error < 1e-5); + + Ok(()) + } + + #[test] + fn attn_softmax_mask_scale() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + + let device = Device::new_metal(0)?; + + let tensor = Tensor::randn(0f32, 1f32, (4, 32, 64, 64), &device)?.to_dtype(DType::BF16)?; + let truemask = Tensor::full(half::bf16::MIN, (64, 64), &device)? + .contiguous()? + .to_dtype(DType::BF16)?; + + let scale = 0.1f32; + + let ground_truth = + mcandle_nn::ops::softmax_last_dim(&(tensor.broadcast_add(&truemask)? * scale as f64)?)? + .to_dtype(DType::F32)?; + + let softmax_out = mcandle_nn::ops::attn_softmax_last_dim(&tensor, &truemask, scale)? + .to_dtype(DType::F32)?; + + let error: f32 = ((&ground_truth - &softmax_out)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_dtype(DType::F32)? + .to_scalar()?; + + assert!(error < 1e-5, "{error}"); + + Ok(()) + } + + #[test] + fn attn_softmax_mask_novec() -> mcandle_core::Result<()> { + use mcandle_core::{Device, Tensor}; + + let device = Device::new_metal(0)?; + + let tensor = Tensor::randn(0f32, 1f32, (4, 32, 64, 63), &device)?; + let truemask = Tensor::full(f32::MIN, (64, 63), &device)?.contiguous()?; + + let ground_truth = mcandle_nn::ops::softmax_last_dim(&tensor.broadcast_add(&truemask)?)?; + + let softmax_out = mcandle_nn::ops::attn_softmax_last_dim(&tensor, &truemask, 1.)?; + + let error: f32 = ((&ground_truth - &softmax_out)?.abs()? / &ground_truth.abs()?)? + .sum_all()? 
+ .to_scalar()?; + + assert!(error < 1e-5); + + Ok(()) + } + + #[test] + fn attn_softmax_mask_scale_novec() -> mcandle_core::Result<()> { + use mcandle_core::{DType, Device, Tensor}; + + let device = Device::new_metal(0)?; + + let tensor = Tensor::randn(0f32, 1f32, (4, 32, 64, 63), &device)?.to_dtype(DType::BF16)?; + let truemask = Tensor::full(half::bf16::MIN, (64, 63), &device)? + .contiguous()? + .to_dtype(DType::BF16)?; + + let scale = 0.1f32; + + let ground_truth = + mcandle_nn::ops::softmax_last_dim(&(tensor.broadcast_add(&truemask)? * scale as f64)?)? + .to_dtype(DType::F32)?; + + let softmax_out = mcandle_nn::ops::attn_softmax_last_dim(&tensor, &truemask, scale)? + .to_dtype(DType::F32)?; + + let error: f32 = ((&ground_truth - &softmax_out)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_dtype(DType::F32)? + .to_scalar()?; + + assert!(error < 1e-5, "{error}"); + + Ok(()) + } +} diff --git a/mistralrs-bench/Cargo.toml b/mistralrs-bench/Cargo.toml index 8d294cb4d..4847f3712 100644 --- a/mistralrs-bench/Cargo.toml +++ b/mistralrs-bench/Cargo.toml @@ -14,7 +14,7 @@ license.workspace = true [dependencies] anyhow.workspace = true -candle-core.workspace = true +mcandle-core.workspace = true serde.workspace = true serde_json.workspace = true clap.workspace = true diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 53895d4c4..e18467c40 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -1,6 +1,6 @@ -use candle_core::Device; use clap::Parser; use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table}; +use mcandle_core::Device; use mistralrs_core::{ get_model_dtype, initialize_logging, paged_attn_supported, parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, IsqType, @@ -368,10 +368,10 @@ fn main() -> anyhow::Result<()> { let token_source = TokenSource::CacheToken; info!( "avx: {}, neon: {}, simd128: {}, f16c: {}", - candle_core::utils::with_avx(), - candle_core::utils::with_neon(), - candle_core::utils::with_simd128(), - candle_core::utils::with_f16c() + mcandle_core::utils::with_avx(), + mcandle_core::utils::with_neon(), + mcandle_core::utils::with_simd128(), + mcandle_core::utils::with_f16c() ); info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial"); if use_flash_attn { diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index c6f184bcb..2ed439ab7 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -13,11 +13,11 @@ homepage.workspace = true [dependencies] anyhow.workspace = true -candle-core.workspace = true -candle-nn.workspace = true +mcandle-core.workspace = true +mcandle-nn.workspace = true serde.workspace = true serde_json.workspace = true -candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "cb8082b", optional = true } +mcandle-flash-attn = { workspace = true, optional = true } dirs = "5.0.1" hf-hub = "0.3.2" thiserror = "1.0.57" @@ -82,12 +82,12 @@ float8.workspace = true [features] pyo3_macros = ["pyo3"] -cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"] -cudnn = ["candle-core/cudnn"] -metal = ["candle-core/metal", "candle-nn/metal"] -flash-attn = ["cuda", "dep:candle-flash-attn"] -accelerate = ["candle-core/accelerate", "candle-nn/accelerate"] -mkl = ["candle-core/mkl", "candle-nn/mkl"] +cuda = ["mcandle-core/cuda", 
"mcandle-nn/cuda", "dep:bindgen_cuda", "mistralrs-quant/cuda", "dep:mistralrs-paged-attn", "mistralrs-paged-attn/cuda", "float8/cuda"] +cudnn = ["mcandle-core/cudnn"] +metal = ["mcandle-core/metal", "mcandle-nn/metal"] +flash-attn = ["cuda", "dep:mcandle-flash-attn"] +accelerate = ["mcandle-core/accelerate", "mcandle-nn/accelerate"] +mkl = ["mcandle-core/mkl", "mcandle-nn/mkl"] [build-dependencies] bindgen_cuda = { version = "0.1.5", optional = true } diff --git a/mistralrs-core/src/amoe/macros.rs b/mistralrs-core/src/amoe/macros.rs index 57e763c24..71dbc192e 100644 --- a/mistralrs-core/src/amoe/macros.rs +++ b/mistralrs-core/src/amoe/macros.rs @@ -29,7 +29,7 @@ macro_rules! merge_delta { QMatMul::QTensor(w) => { let (w, dtype) = (w.dequantize(&w.device())?, w.dtype()); QMatMul::QTensor(std::sync::Arc::new( - candle_core::quantized::QTensor::quantize(&(w + $delta)?, dtype)?, + mcandle_core::quantized::QTensor::quantize(&(w + $delta)?, dtype)?, )) } } diff --git a/mistralrs-core/src/amoe/mod.rs b/mistralrs-core/src/amoe/mod.rs index 0c3efa797..dbe0a22a7 100644 --- a/mistralrs-core/src/amoe/mod.rs +++ b/mistralrs-core/src/amoe/mod.rs @@ -5,8 +5,8 @@ use std::{ sync::{Arc, RwLock}, }; -use candle_core::{safetensors, DType, Device, Result, Tensor, Var, D}; -use candle_nn::{linear, Linear, ModuleT, VarBuilder, VarMap}; +use mcandle_core::{safetensors, DType, Device, Result, Tensor, Var, D}; +use mcandle_nn::{linear, Linear, ModuleT, VarBuilder, VarMap}; use mistralrs_quant::QuantMethod; use serde::{Deserialize, Serialize}; @@ -78,7 +78,7 @@ pub trait AnyMoeBaseModelMixin { _expert_type: AnyMoeExpertType, _gate_vb: Option, ) -> Result<()> { - candle_core::bail!("Model does not support AnyMoE layers"); + mcandle_core::bail!("Model does not support AnyMoE layers"); } fn get_mlps(&self) -> Vec<&dyn MlpLayer> { panic!("Model does not support AnyMoE layers"); @@ -165,9 +165,9 @@ impl ModuleT for MoeGate { fn forward_t(&self, xs: &Tensor, train: bool) -> Result { let hidden_states = xs.apply(&self.lin)?; if train { - candle_nn::ops::softmax(&hidden_states, D::Minus1) + mcandle_nn::ops::softmax(&hidden_states, D::Minus1) } else { - candle_nn::ops::softmax_last_dim(&hidden_states) + mcandle_nn::ops::softmax_last_dim(&hidden_states) } } } @@ -207,7 +207,7 @@ impl MoeMlp { let vars = var_map.all_vars(); if vars.is_empty() && !inference { - candle_core::bail!("No vars to train in MoeMlp, perhaps there are no layers?"); + mcandle_core::bail!("No vars to train in MoeMlp, perhaps there are no layers?"); } Ok(Self { experts, diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs index 63390aad1..c5bc34ca2 100644 --- a/mistralrs-core/src/attention.rs +++ b/mistralrs-core/src/attention.rs @@ -6,7 +6,7 @@ use crate::{ pipeline::text_models_inputs_processor::FlashParams, }; -use candle_core::{Device, Result, Tensor}; +use mcandle_core::{Device, Result, Tensor}; #[cfg(feature = "flash-attn")] fn flash_attn( @@ -37,7 +37,7 @@ fn flash_attn( let window_size_right = if causal { Some(0) } else { None }; //dbg!(&qshape); - candle_flash_attn::flash_attn_varlen_windowed_softcap( + mcandle_flash_attn::flash_attn_varlen_windowed_softcap( &q, &k, &v, @@ -52,7 +52,7 @@ fn flash_attn( )? .reshape(qshape) } else { - candle_flash_attn::flash_attn_softcap( + mcandle_flash_attn::flash_attn_softcap( q, k, v, @@ -100,7 +100,8 @@ fn naive_sdpa( att = (att * softcap as f64)?; } - let att = candle_nn::ops::attn_softmax_last_dim(&att, mask, 1. 
/ (head_dim as f32).sqrt())?; + let att = + mcandle_nn::ops::attn_softmax_last_dim(&att, mask, 1. / (head_dim as f32).sqrt())?; MatMul.matmul(&att, v) } else { let mut att = MatMul.matmul_affine_div(q, &k.t()?, (head_dim as f64).sqrt())?; @@ -114,7 +115,7 @@ fn naive_sdpa( Some(m) => att.broadcast_add(m)?, None => att, }; - let att = candle_nn::ops::softmax_last_dim(&att)?; + let att = mcandle_nn::ops::softmax_last_dim(&att)?; MatMul.matmul(&att, v) } } @@ -161,7 +162,7 @@ impl Sdpa { } if q.device().is_metal() && seq_len == 1 { - return candle_nn::ops::sdpa( + return mcandle_nn::ops::sdpa( q, k, v, @@ -203,7 +204,7 @@ impl Sdpa { if let Some(softcap) = sdpa_params.softcap { attention_scores = (attention_scores.tanh()? * softcap as f64)?; } - let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?; + let attention_probs = mcandle_nn::ops::softmax_last_dim(&attention_scores)?; let context_layer = cublaslt.batch_matmul( &v.t()?.contiguous()?, @@ -221,7 +222,7 @@ impl Sdpa { } #[cfg(not(feature = "cuda"))] { - candle_core::bail!("`cuda` feature is not enabled") + mcandle_core::bail!("`cuda` feature is not enabled") } } else { // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt diff --git a/mistralrs-core/src/cublaslt/api.rs b/mistralrs-core/src/cublaslt/api.rs index 8bb11d028..41edb10bb 100644 --- a/mistralrs-core/src/cublaslt/api.rs +++ b/mistralrs-core/src/cublaslt/api.rs @@ -1,11 +1,11 @@ -use candle_core::cuda::cudarc::driver::DevicePtr; use float8::F8E4M3; +use mcandle_core::cuda::cudarc::driver::DevicePtr; use std::ffi::c_int; -use candle_core::backend::BackendStorage; -use candle_core::cuda_backend::WrapErr; -use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor}; use half::{bf16, f16}; +use mcandle_core::backend::BackendStorage; +use mcandle_core::cuda_backend::WrapErr; +use mcandle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor}; use std::sync::Arc; use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig}; @@ -17,7 +17,7 @@ impl CublasLt { pub fn new(device: &Device) -> Result { let dev = match device { Device::Cuda(d) => d, - _ => candle_core::bail!("`device` must be a `cuda` device"), + _ => mcandle_core::bail!("`device` must be a `cuda` device"), }; let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); @@ -37,13 +37,13 @@ pub struct CublasLTMatmul { impl CublasLTMatmul { pub fn fwd_f16( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -52,7 +52,7 @@ impl CublasLTMatmul { let (n, b_1) = b_l.shape().dims2()?; if b_1 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } let lda = k; @@ -66,7 +66,7 @@ impl CublasLTMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? 
!= m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -78,21 +78,21 @@ impl CublasLTMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims2()? != (n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } c.clone() @@ -122,23 +122,23 @@ impl CublasLTMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_bf16( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -147,7 +147,7 @@ impl CublasLTMatmul { let (n, b_1) = b_l.shape().dims2()?; if b_1 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } let lda = k; @@ -161,7 +161,7 @@ impl CublasLTMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -173,21 +173,21 @@ impl CublasLTMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims2()? 
!= (n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } c.clone() @@ -217,23 +217,23 @@ impl CublasLTMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_f32( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -242,7 +242,7 @@ impl CublasLTMatmul { let (n, b_1) = b_l.shape().dims2()?; if b_1 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } let lda = k; @@ -256,7 +256,7 @@ impl CublasLTMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -268,21 +268,21 @@ impl CublasLTMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims2()? 
!= (n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } c.clone() @@ -312,16 +312,16 @@ impl CublasLTMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } } -impl candle_core::CustomOp2 for CublasLTMatmul { +impl mcandle_core::CustomOp2 for CublasLTMatmul { fn name(&self) -> &'static str { "cublaslt-matmul" } @@ -333,26 +333,28 @@ impl candle_core::CustomOp2 for CublasLTMatmul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-matmul") + mcandle_core::bail!("no cpu support for cublaslt-matmul") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), - candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), - candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), - dt => candle_core::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), + mcandle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + mcandle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + mcandle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + dt => { + mcandle_core::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})") + } } } } -impl candle_core::CustomOp3 for CublasLTMatmul { +impl mcandle_core::CustomOp3 for CublasLTMatmul { fn name(&self) -> &'static str { "cublaslt-matmul-add" } @@ -366,23 +368,25 @@ impl candle_core::CustomOp3 for CublasLTMatmul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-matmul") + mcandle_core::bail!("no cpu support for cublaslt-matmul") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: &candle_core::CudaStorage, + bias: &mcandle_core::CudaStorage, bias_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), - dt => candle_core::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})"), + mcandle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + mcandle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + mcandle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => { + mcandle_core::bail!("cublaslt-matmul is only supported for f16/bf16/f32 ({dt:?})") + } } } } @@ -439,13 +443,13 @@ pub struct CublasLTBatchMatmul { impl CublasLTBatchMatmul { pub fn fwd_f16( &self, - a: 
&candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -453,11 +457,11 @@ impl CublasLTBatchMatmul { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + mcandle_core::bail!("`b` must have the same batch size as `a`") } let lda = k; @@ -471,7 +475,7 @@ impl CublasLTBatchMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -483,22 +487,22 @@ impl CublasLTBatchMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -532,23 +536,23 @@ impl CublasLTBatchMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_bf16( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -556,11 +560,11 @@ impl CublasLTBatchMatmul { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + mcandle_core::bail!("`b` must have the same batch size as `a`") } let lda = k; @@ -574,7 +578,7 @@ impl CublasLTBatchMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? 
!= m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -586,22 +590,22 @@ impl CublasLTBatchMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -635,23 +639,23 @@ impl CublasLTBatchMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_f32( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -659,11 +663,11 @@ impl CublasLTBatchMatmul { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + mcandle_core::bail!("`b` must have the same batch size as `a`") } let lda = k; @@ -677,7 +681,7 @@ impl CublasLTBatchMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -689,22 +693,22 @@ impl CublasLTBatchMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? 
!= (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -738,16 +742,16 @@ impl CublasLTBatchMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } } -impl candle_core::CustomOp2 for CublasLTBatchMatmul { +impl mcandle_core::CustomOp2 for CublasLTBatchMatmul { fn name(&self) -> &'static str { "cublaslt-batch-matmul" } @@ -759,22 +763,22 @@ impl candle_core::CustomOp2 for CublasLTBatchMatmul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-batch-matmul") + mcandle_core::bail!("no cpu support for cublaslt-batch-matmul") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), - candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), - candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + mcandle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + mcandle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + mcandle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), dt => { - candle_core::bail!( + mcandle_core::bail!( "cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})" ) } @@ -782,7 +786,7 @@ impl candle_core::CustomOp2 for CublasLTBatchMatmul { } } -impl candle_core::CustomOp3 for CublasLTBatchMatmul { +impl mcandle_core::CustomOp3 for CublasLTBatchMatmul { fn name(&self) -> &'static str { "cublaslt-batch-matmul-add" } @@ -796,23 +800,23 @@ impl candle_core::CustomOp3 for CublasLTBatchMatmul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-batch-matmul-add") + mcandle_core::bail!("no cpu support for cublaslt-batch-matmul-add") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: &candle_core::CudaStorage, + bias: &mcandle_core::CudaStorage, bias_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), - dt => candle_core::bail!( + mcandle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + mcandle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + mcandle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => mcandle_core::bail!( "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})" ), } @@ -864,7 +868,7 @@ mod tests { use std::f32::consts::PI; use super::*; - 
use candle_core::{DType, Device, IndexOp}; + use mcandle_core::{DType, Device, IndexOp}; fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { let b = 10f32.powi(digits); diff --git a/mistralrs-core/src/cublaslt/matmul.rs b/mistralrs-core/src/cublaslt/matmul.rs index 898a30522..f791c9ca6 100644 --- a/mistralrs-core/src/cublaslt/matmul.rs +++ b/mistralrs-core/src/cublaslt/matmul.rs @@ -1,13 +1,13 @@ -use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute; -use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys}; -use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream}; -use candle_core::cuda::cudarc::driver::{ - CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError, -}; use core::ffi::c_int; use core::mem; use float8::F8E4M3; use half::bf16; +use mcandle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute; +use mcandle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys}; +use mcandle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream}; +use mcandle_core::cuda::cudarc::driver::{ + CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError, +}; use std::sync::Arc; /// Wrapper around [sys::cublasLtHandle_t] diff --git a/mistralrs-core/src/cublaslt/mod.rs b/mistralrs-core/src/cublaslt/mod.rs index 9d6046b38..2eb711235 100644 --- a/mistralrs-core/src/cublaslt/mod.rs +++ b/mistralrs-core/src/cublaslt/mod.rs @@ -1,9 +1,9 @@ -// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs +// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/mcandle/src/layers/cublaslt.rs #![allow(unused_variables, unused_imports, dead_code)] -use candle_core::{Device, Result, Tensor}; -use candle_nn::Activation as CandleActivation; +use mcandle_core::{Device, Result, Tensor}; +use mcandle_nn::Activation as CandleActivation; use once_cell::sync::Lazy; use std::sync::{Mutex, Once}; @@ -33,7 +33,7 @@ pub fn setup_cublas_lt_wrapper() { // Check if we can call the driver // Then check if we can create a device // Then check that the device is CUDA - use candle_core::cuda_backend::cudarc::driver; + use mcandle_core::cuda_backend::cudarc::driver; CUBLASLT = driver::result::init() .ok() .and_then(|_| Device::cuda_if_available(0).ok()) @@ -88,13 +88,13 @@ impl CublasLtWrapper { )?; if Some(CandleActivation::Swiglu) == act { - result = candle_nn::ops::swiglu(&result)?; + result = mcandle_nn::ops::swiglu(&result)?; } Ok(result) } #[cfg(not(feature = "cuda"))] { - candle_core::bail!("`cuda` feature is not enabled") + mcandle_core::bail!("`cuda` feature is not enabled") } } } diff --git a/mistralrs-core/src/device_map.rs b/mistralrs-core/src/device_map.rs index 885ed6ec8..0aa1a1926 100644 --- a/mistralrs-core/src/device_map.rs +++ b/mistralrs-core/src/device_map.rs @@ -1,8 +1,8 @@ use std::fmt::Debug; use crate::{utils::debug::DeviceRepr, Topology, TryIntoDType}; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::VarBuilder; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::VarBuilder; use serde::Deserialize; use tracing::info; @@ -90,7 +90,7 @@ impl DeviceMapMetadata { .host_layers .unwrap_or(model_layers.saturating_sub(n_device_layers)); if n_device_layers + n_host_layers != model_layers { - candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number 
of model hidden layers ({model_layers})"); + mcandle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})"); } info!("Model has {model_layers} repeating layers."); @@ -203,7 +203,7 @@ impl DeviceMapper for LayerDeviceMapper { fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result { dtype .try_into_dtype(&self.mappings.iter().collect::>()) - .map_err(candle_core::Error::msg) + .map_err(mcandle_core::Error::msg) } } @@ -251,6 +251,6 @@ impl DeviceMapper for DummyDeviceMapper { fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result { dtype .try_into_dtype(&[&self.nm_device]) - .map_err(candle_core::Error::msg) + .map_err(mcandle_core::Error::msg) } } diff --git a/mistralrs-core/src/diffusion_models/clip/text.rs b/mistralrs-core/src/diffusion_models/clip/text.rs index 2d8a58e32..6739055ff 100644 --- a/mistralrs-core/src/diffusion_models/clip/text.rs +++ b/mistralrs-core/src/diffusion_models/clip/text.rs @@ -1,8 +1,8 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn as nn; -use candle_nn::Module; +use mcandle_core::{DType, Device, IndexOp, Result, Tensor, D}; +use mcandle_nn as nn; +use mcandle_nn::Module; use serde::Deserialize; #[derive(Debug, Clone, Copy, Deserialize)] @@ -39,16 +39,16 @@ pub struct ClipConfig { // TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142 #[derive(Clone, Debug)] struct ClipTextEmbeddings { - token_embedding: candle_nn::Embedding, - position_embedding: candle_nn::Embedding, + token_embedding: mcandle_nn::Embedding, + position_embedding: mcandle_nn::Embedding, position_ids: Tensor, } impl ClipTextEmbeddings { - fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipTextConfig) -> Result { let token_embedding = - candle_nn::embedding(c.vocab_size, c.projection_dim, vs.pp("token_embedding"))?; - let position_embedding: nn::Embedding = candle_nn::embedding( + mcandle_nn::embedding(c.vocab_size, c.projection_dim, vs.pp("token_embedding"))?; + let position_embedding: nn::Embedding = mcandle_nn::embedding( c.max_position_embeddings, c.projection_dim, vs.pp("position_embedding"), @@ -75,23 +75,23 @@ impl Module for ClipTextEmbeddings { #[derive(Clone, Debug)] struct ClipAttention { - k_proj: candle_nn::Linear, - v_proj: candle_nn::Linear, - q_proj: candle_nn::Linear, - out_proj: candle_nn::Linear, + k_proj: mcandle_nn::Linear, + v_proj: mcandle_nn::Linear, + q_proj: mcandle_nn::Linear, + out_proj: mcandle_nn::Linear, head_dim: usize, scale: f64, num_attention_heads: usize, } impl ClipAttention { - fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipTextConfig) -> Result { let projection_dim = c.projection_dim; let num_attention_heads = c.num_attention_heads; - let k_proj = candle_nn::linear(projection_dim, projection_dim, vs.pp("k_proj"))?; - let v_proj = candle_nn::linear(projection_dim, projection_dim, vs.pp("v_proj"))?; - let q_proj = candle_nn::linear(projection_dim, projection_dim, vs.pp("q_proj"))?; - let out_proj = candle_nn::linear(projection_dim, projection_dim, vs.pp("out_proj"))?; + let k_proj = mcandle_nn::linear(projection_dim, projection_dim, vs.pp("k_proj"))?; + let v_proj = 
mcandle_nn::linear(projection_dim, projection_dim, vs.pp("v_proj"))?; + let q_proj = mcandle_nn::linear(projection_dim, projection_dim, vs.pp("q_proj"))?; + let out_proj = mcandle_nn::linear(projection_dim, projection_dim, vs.pp("out_proj"))?; let head_dim = projection_dim / num_attention_heads; let scale = (head_dim as f64).powf(-0.5); @@ -143,7 +143,7 @@ impl ClipAttention { attn_weights }; - let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + let attn_weights = mcandle_nn::ops::softmax(&attn_weights, D::Minus1)?; let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?; let attn_output = attn_output @@ -156,15 +156,15 @@ impl ClipAttention { #[derive(Clone, Debug)] struct ClipMlp { - fc1: candle_nn::Linear, - fc2: candle_nn::Linear, + fc1: mcandle_nn::Linear, + fc2: mcandle_nn::Linear, activation: Activation, } impl ClipMlp { - fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { - let fc1 = candle_nn::linear(c.projection_dim, c.intermediate_size, vs.pp("fc1"))?; - let fc2 = candle_nn::linear(c.intermediate_size, c.projection_dim, vs.pp("fc2"))?; + fn new(vs: mcandle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + let fc1 = mcandle_nn::linear(c.projection_dim, c.intermediate_size, vs.pp("fc1"))?; + let fc2 = mcandle_nn::linear(c.intermediate_size, c.projection_dim, vs.pp("fc2"))?; Ok(ClipMlp { fc1, @@ -184,17 +184,17 @@ impl ClipMlp { #[derive(Clone, Debug)] struct ClipEncoderLayer { self_attn: ClipAttention, - layer_norm1: candle_nn::LayerNorm, + layer_norm1: mcandle_nn::LayerNorm, mlp: ClipMlp, - layer_norm2: candle_nn::LayerNorm, + layer_norm2: mcandle_nn::LayerNorm, } impl ClipEncoderLayer { - fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipTextConfig) -> Result { let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; - let layer_norm1 = candle_nn::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm1"))?; + let layer_norm1 = mcandle_nn::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm1"))?; let mlp = ClipMlp::new(vs.pp("mlp"), c)?; - let layer_norm2 = candle_nn::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm2"))?; + let layer_norm2 = mcandle_nn::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm2"))?; Ok(ClipEncoderLayer { self_attn, @@ -223,7 +223,7 @@ pub struct ClipEncoder { } impl ClipEncoder { - pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + pub fn new(vs: mcandle_nn::VarBuilder, c: &ClipTextConfig) -> Result { let vs = vs.pp("layers"); let mut layers: Vec = Vec::new(); for index in 0..c.num_hidden_layers { @@ -247,15 +247,15 @@ impl ClipEncoder { pub struct ClipTextTransformer { embeddings: ClipTextEmbeddings, encoder: ClipEncoder, - final_layer_norm: candle_nn::LayerNorm, + final_layer_norm: mcandle_nn::LayerNorm, } impl ClipTextTransformer { - pub fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result { + pub fn new(vs: mcandle_nn::VarBuilder, c: &ClipTextConfig) -> Result { let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; let encoder = ClipEncoder::new(vs.pp("encoder"), c)?; let final_layer_norm = - candle_nn::layer_norm(c.projection_dim, 1e-5, vs.pp("final_layer_norm"))?; + mcandle_nn::layer_norm(c.projection_dim, 1e-5, vs.pp("final_layer_norm"))?; Ok(ClipTextTransformer { embeddings, encoder, diff --git a/mistralrs-core/src/diffusion_models/flux/autoencoder.rs b/mistralrs-core/src/diffusion_models/flux/autoencoder.rs index cec73be4b..11a97920e 100644 --- 
a/mistralrs-core/src/diffusion_models/flux/autoencoder.rs +++ b/mistralrs-core/src/diffusion_models/flux/autoencoder.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{Result, Tensor, D}; -use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder}; +use mcandle_core::{Result, Tensor, D}; +use mcandle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder}; use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] @@ -20,7 +20,7 @@ fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result Result { let init_xs = xs; let xs = xs.apply(&self.norm)?; @@ -77,7 +77,7 @@ struct ResnetBlock { impl ResnetBlock { fn new(in_c: usize, out_c: usize, vb: VarBuilder, cfg: &Config) -> Result { - let conv_cfg = candle_nn::Conv2dConfig { + let conv_cfg = mcandle_nn::Conv2dConfig { padding: 1, ..Default::default() }; @@ -106,14 +106,14 @@ impl ResnetBlock { } } -impl candle_core::Module for ResnetBlock { +impl mcandle_core::Module for ResnetBlock { fn forward(&self, xs: &Tensor) -> Result { let h = xs .apply(&self.norm1)? - .apply(&candle_nn::Activation::Swish)? + .apply(&mcandle_nn::Activation::Swish)? .apply(&self.conv1)? .apply(&self.norm2)? - .apply(&candle_nn::Activation::Swish)? + .apply(&mcandle_nn::Activation::Swish)? .apply(&self.conv2)?; match self.nin_shortcut.as_ref() { None => xs + h, @@ -129,7 +129,7 @@ struct Downsample { impl Downsample { fn new(in_c: usize, vb: VarBuilder) -> Result { - let conv_cfg = candle_nn::Conv2dConfig { + let conv_cfg = mcandle_nn::Conv2dConfig { stride: 2, ..Default::default() }; @@ -138,7 +138,7 @@ impl Downsample { } } -impl candle_core::Module for Downsample { +impl mcandle_core::Module for Downsample { fn forward(&self, xs: &Tensor) -> Result { let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?; let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?; @@ -153,7 +153,7 @@ struct Upsample { impl Upsample { fn new(in_c: usize, vb: VarBuilder) -> Result { - let conv_cfg = candle_nn::Conv2dConfig { + let conv_cfg = mcandle_nn::Conv2dConfig { padding: 1, ..Default::default() }; @@ -162,7 +162,7 @@ impl Upsample { } } -impl candle_core::Module for Upsample { +impl mcandle_core::Module for Upsample { fn forward(&self, xs: &Tensor) -> Result { let (_, _, h, w) = xs.dims4()?; xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv) @@ -188,7 +188,7 @@ pub struct Encoder { impl Encoder { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let conv_cfg = candle_nn::Conv2dConfig { + let conv_cfg = mcandle_nn::Conv2dConfig { padding: 1, ..Default::default() }; @@ -245,7 +245,7 @@ impl Encoder { } } -impl candle_nn::Module for Encoder { +impl mcandle_nn::Module for Encoder { fn forward(&self, xs: &Tensor) -> Result { let mut h = xs.apply(&self.conv_in)?; for block in self.down.iter() { @@ -260,7 +260,7 @@ impl candle_nn::Module for Encoder { .apply(&self.mid_attn_1)? .apply(&self.mid_block_2)? .apply(&self.norm_out)? - .apply(&candle_nn::Activation::Swish)? + .apply(&mcandle_nn::Activation::Swish)? 
.apply(&self.conv_out) } } @@ -284,7 +284,7 @@ pub struct Decoder { impl Decoder { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let conv_cfg = candle_nn::Conv2dConfig { + let conv_cfg = mcandle_nn::Conv2dConfig { padding: 1, ..Default::default() }; @@ -331,7 +331,7 @@ impl Decoder { } } -impl candle_nn::Module for Decoder { +impl mcandle_nn::Module for Decoder { fn forward(&self, xs: &Tensor) -> Result { let h = xs.apply(&self.conv_in)?; let mut h = h @@ -347,7 +347,7 @@ impl candle_nn::Module for Decoder { } } h.apply(&self.norm_out)? - .apply(&candle_nn::Activation::Swish)? + .apply(&mcandle_nn::Activation::Swish)? .apply(&self.conv_out) } } @@ -364,7 +364,7 @@ impl DiagonalGaussian { } } -impl candle_nn::Module for DiagonalGaussian { +impl mcandle_nn::Module for DiagonalGaussian { fn forward(&self, xs: &Tensor) -> Result { let chunks = xs.chunk(2, self.chunk_dim)?; if self.sample { @@ -409,7 +409,7 @@ impl AutoEncoder { } } -impl candle_core::Module for AutoEncoder { +impl mcandle_core::Module for AutoEncoder { fn forward(&self, xs: &Tensor) -> Result { self.decode(&self.encode(xs)?) } diff --git a/mistralrs-core/src/diffusion_models/flux/model.rs b/mistralrs-core/src/diffusion_models/flux/model.rs index f9e8694bd..8c4fc1adf 100644 --- a/mistralrs-core/src/diffusion_models/flux/model.rs +++ b/mistralrs-core/src/diffusion_models/flux/model.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, Linear, RmsNorm, VarBuilder}; +use mcandle_core::{DType, Device, IndexOp, Result, Tensor, D}; +use mcandle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, Linear, RmsNorm, VarBuilder}; use serde::Deserialize; const MLP_RATIO: f64 = 4.; @@ -35,7 +35,7 @@ fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result Result Result { if dim % 2 == 1 { - candle_core::bail!("dim {dim} is odd") + mcandle_core::bail!("dim {dim} is odd") } let dev = pos.device(); let theta = theta as f64; @@ -84,16 +84,16 @@ fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result { const TIME_FACTOR: f64 = 1000.; const MAX_PERIOD: f64 = 10000.; if dim % 2 == 1 { - candle_core::bail!("{dim} is odd") + mcandle_core::bail!("{dim} is odd") } let dev = t.device(); let half = dim / 2; let t = (t * TIME_FACTOR)?; - let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle_core::DType::F32)?; + let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(mcandle_core::DType::F32)?; let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?; let args = t .unsqueeze(1)? - .to_dtype(candle_core::DType::F32)? + .to_dtype(mcandle_core::DType::F32)? 
.broadcast_mul(&freqs.unsqueeze(0)?)?; let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?; Ok(emb) @@ -117,7 +117,7 @@ impl EmbedNd { } } -impl candle_core::Module for EmbedNd { +impl mcandle_core::Module for EmbedNd { fn forward(&self, ids: &Tensor) -> Result { let n_axes = ids.dim(D::Minus1)?; let mut emb = Vec::with_capacity(n_axes); @@ -142,8 +142,8 @@ pub struct MlpEmbedder { impl MlpEmbedder { fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result { - let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?; - let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?; + let in_layer = mcandle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?; + let out_layer = mcandle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?; Ok(Self { in_layer, out_layer, @@ -151,7 +151,7 @@ impl MlpEmbedder { } } -impl candle_core::Module for MlpEmbedder { +impl mcandle_core::Module for MlpEmbedder { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer) } @@ -200,7 +200,7 @@ struct Modulation1 { impl Modulation1 { fn new(dim: usize, vb: VarBuilder) -> Result { - let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?; + let lin = mcandle_nn::linear(dim, 3 * dim, vb.pp("lin"))?; Ok(Self { lin }) } @@ -211,7 +211,7 @@ impl Modulation1 { .unsqueeze(1)? .chunk(3, D::Minus1)?; if ys.len() != 3 { - candle_core::bail!("unexpected len from chunk {ys:?}") + mcandle_core::bail!("unexpected len from chunk {ys:?}") } Ok(ModulationOut { shift: ys[0].clone(), @@ -228,7 +228,7 @@ struct Modulation2 { impl Modulation2 { fn new(dim: usize, vb: VarBuilder) -> Result { - let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?; + let lin = mcandle_nn::linear(dim, 6 * dim, vb.pp("lin"))?; Ok(Self { lin }) } @@ -239,7 +239,7 @@ impl Modulation2 { .unsqueeze(1)? 
.chunk(6, D::Minus1)?; if ys.len() != 6 { - candle_core::bail!("unexpected len from chunk {ys:?}") + mcandle_core::bail!("unexpected len from chunk {ys:?}") } let mod1 = ModulationOut { shift: ys[0].clone(), @@ -266,9 +266,9 @@ pub struct SelfAttention { impl SelfAttention { fn new(dim: usize, num_attention_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result { let head_dim = dim / num_attention_heads; - let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?; + let qkv = mcandle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?; let norm = QkNorm::new(head_dim, vb.pp("norm"))?; - let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?; + let proj = mcandle_nn::linear(dim, dim, vb.pp("proj"))?; Ok(Self { qkv, norm, @@ -326,8 +326,8 @@ struct Mlp { impl Mlp { fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result { - let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?; - let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?; + let lin1 = mcandle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?; + let lin2 = mcandle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?; Ok(Self { lin1, lin2 }) } @@ -344,7 +344,7 @@ impl Mlp { } } -impl candle_core::Module for Mlp { +impl mcandle_core::Module for Mlp { fn forward(&self, xs: &Tensor) -> Result { xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) } @@ -498,8 +498,8 @@ impl SingleStreamBlock { let h_sz = HIDDEN_SIZE; let mlp_sz = (h_sz as f64 * MLP_RATIO) as usize; let head_dim = h_sz / cfg.num_attention_heads; - let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?; - let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; + let linear1 = mcandle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?; + let linear2 = mcandle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?; let norm = QkNorm::new(head_dim, vb.pp("norm"))?; let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?; let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?; @@ -578,8 +578,8 @@ pub struct LastLayer { impl LastLayer { fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result { let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?; - let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?; - let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?; + let linear = mcandle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?; + let ada_ln_modulation = mcandle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?; Ok(Self { norm_final, linear, @@ -615,12 +615,12 @@ pub struct Flux { impl Flux { pub fn new(cfg: &Config, vb: VarBuilder, device: Device, offloaded: bool) -> Result { - let img_in = candle_nn::linear( + let img_in = mcandle_nn::linear( cfg.in_channels, HIDDEN_SIZE, vb.pp("img_in").set_device(device.clone()), )?; - let txt_in = candle_nn::linear( + let txt_in = mcandle_nn::linear( cfg.joint_attention_dim, HIDDEN_SIZE, vb.pp("txt_in").set_device(device.clone()), @@ -692,10 +692,10 @@ impl Flux { guidance: Option<&Tensor>, ) -> Result { if txt.rank() != 3 { - candle_core::bail!("unexpected shape for txt {:?}", txt.shape()) + mcandle_core::bail!("unexpected shape for txt {:?}", txt.shape()) } if img.rank() != 3 { - candle_core::bail!("unexpected shape for img {:?}", img.shape()) + mcandle_core::bail!("unexpected shape for img {:?}", img.shape()) } let dtype = img.dtype(); let pe = { diff --git a/mistralrs-core/src/diffusion_models/flux/sampling.rs b/mistralrs-core/src/diffusion_models/flux/sampling.rs index ca9a046b6..e883559da 100644 --- 
a/mistralrs-core/src/diffusion_models/flux/sampling.rs +++ b/mistralrs-core/src/diffusion_models/flux/sampling.rs @@ -1,6 +1,6 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{Device, Result, Tensor}; +use mcandle_core::{Device, Result, Tensor}; pub fn get_noise( num_samples: usize, diff --git a/mistralrs-core/src/diffusion_models/flux/stepper.rs b/mistralrs-core/src/diffusion_models/flux/stepper.rs index 3a7501a9a..905d3ea1d 100644 --- a/mistralrs-core/src/diffusion_models/flux/stepper.rs +++ b/mistralrs-core/src/diffusion_models/flux/stepper.rs @@ -1,8 +1,8 @@ use std::{cmp::Ordering, fs::File}; -use candle_core::{DType, Device, Result, Tensor, D}; -use candle_nn::{Module, VarBuilder}; use hf_hub::api::sync::{Api, ApiError}; +use mcandle_core::{DType, Device, Result, Tensor, D}; +use mcandle_nn::{Module, VarBuilder}; use tokenizers::Tokenizer; use tracing::info; @@ -88,7 +88,7 @@ fn get_t5_model( device: &Device, silent: bool, offloaded: bool, -) -> candle_core::Result { +) -> mcandle_core::Result { let repo = api.repo(hf_hub::Repo::with_revision( "EricB/t5-v1_1-xxl-enc-only".to_string(), hf_hub::RepoType::Model, @@ -100,7 +100,7 @@ fn get_t5_model( .iter() .map(|f| repo.get(f)) .collect::, ApiError>>() - .map_err(candle_core::Error::msg)?, + .map_err(mcandle_core::Error::msg)?, vec![], Some(dtype), device, @@ -108,9 +108,9 @@ fn get_t5_model( None, |_| true, )?; - let config_filename = repo.get("config.json").map_err(candle_core::Error::msg)?; + let config_filename = repo.get("config.json").map_err(mcandle_core::Error::msg)?; let config = std::fs::read_to_string(config_filename)?; - let config: t5::Config = serde_json::from_str(&config).map_err(candle_core::Error::msg)?; + let config: t5::Config = serde_json::from_str(&config).map_err(mcandle_core::Error::msg)?; t5::T5EncoderModel::load(vb, &config, device, offloaded) } @@ -142,7 +142,7 @@ fn get_clip_model_and_tokenizer( fn get_tokenization(tok: &Tokenizer, prompts: Vec, device: &Device) -> Result { Tensor::new( tok.encode_batch(prompts, true) - .map_err(|e| candle_core::Error::Msg(e.to_string()))? + .map_err(|e| mcandle_core::Error::Msg(e.to_string()))? 
.into_iter() .map(|e| e.get_ids().to_vec()) .collect::>(), @@ -194,7 +194,7 @@ impl DiffusionModel for FluxStepper { if !self.is_guidance { match t5_input_ids.dim(1)?.cmp(&256) { Ordering::Greater => { - candle_core::bail!("T5 embedding length greater than 256, please shrink the prompt or use the -dev (with guidance distillation) version.") + mcandle_core::bail!("T5 embedding length greater than 256, please shrink the prompt or use the -dev (with guidance distillation) version.") } Ordering::Less | Ordering::Equal => { t5_input_ids = diff --git a/mistralrs-core/src/diffusion_models/processor.rs b/mistralrs-core/src/diffusion_models/processor.rs index 17670fa95..cec799f06 100644 --- a/mistralrs-core/src/diffusion_models/processor.rs +++ b/mistralrs-core/src/diffusion_models/processor.rs @@ -1,8 +1,8 @@ use std::{any::Any, num::NonZeroUsize, sync::Arc}; use anyhow::{Context, Result}; -use candle_core::Device; use indexmap::IndexMap; +use mcandle_core::Device; use tokenizers::Tokenizer; use crate::{ diff --git a/mistralrs-core/src/diffusion_models/response.rs b/mistralrs-core/src/diffusion_models/response.rs index cc6e92e67..44714fbfb 100644 --- a/mistralrs-core/src/diffusion_models/response.rs +++ b/mistralrs-core/src/diffusion_models/response.rs @@ -12,9 +12,9 @@ use crate::{ pub async fn send_responses( input_seqs: &mut [&mut Sequence], images: Vec, -) -> candle_core::Result<()> { +) -> mcandle_core::Result<()> { if input_seqs.len() != images.len() { - candle_core::bail!( + mcandle_core::bail!( "Input seqs len ({}) does not match images generated len ({})", input_seqs.len(), images.len() @@ -30,7 +30,7 @@ pub async fn send_responses( let saved_path = format!("image-generation-{}.png", Uuid::new_v4()); image .save_with_format(&saved_path, image::ImageFormat::Png) - .map_err(|e| candle_core::Error::Msg(e.to_string()))?; + .map_err(|e| mcandle_core::Error::Msg(e.to_string()))?; ImageChoice { url: Some(saved_path), b64_json: None, @@ -61,7 +61,7 @@ pub async fn send_responses( seq.responder(), ) .await - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; seq.set_state(SequenceState::Done(StopReason::GeneratedImage)); } diff --git a/mistralrs-core/src/diffusion_models/t5/mod.rs b/mistralrs-core/src/diffusion_models/t5/mod.rs index e9980ae6b..833e0c9be 100644 --- a/mistralrs-core/src/diffusion_models/t5/mod.rs +++ b/mistralrs-core/src/diffusion_models/t5/mod.rs @@ -3,9 +3,9 @@ // T5 Text Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py -use candle_core::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, VarBuilder}; use float8::F8E4M3; +use mcandle_core::{DType, Device, Module, Result, Tensor, D}; +use mcandle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, VarBuilder}; use serde::Deserialize; use std::sync::Arc; @@ -42,7 +42,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result #[derive(Debug, Deserialize, Default, Clone, PartialEq)] pub struct ActivationWithOptionalGating { pub gated: bool, - pub activation: candle_nn::Activation, + pub activation: mcandle_nn::Activation, } pub fn deserialize_feed_forward_proj_activation<'de, D>( @@ -54,11 +54,11 @@ where match String::deserialize(deserializer)?.as_str() { "gated-gelu" => Ok(ActivationWithOptionalGating { gated: true, - activation: candle_nn::Activation::NewGelu, + activation: mcandle_nn::Activation::NewGelu, }), "gated-silu" => Ok(ActivationWithOptionalGating 
{ gated: true, - activation: candle_nn::Activation::Silu, + activation: mcandle_nn::Activation::Silu, }), buf => { let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?; @@ -448,7 +448,7 @@ impl T5Attention { }, }; - let attn_weights = { candle_nn::ops::softmax_last_dim(&scores)? }; + let attn_weights = { mcandle_nn::ops::softmax_last_dim(&scores)? }; let attn_output = attn_weights.matmul(&v)?; let attn_output = attn_output .transpose(1, 2)? diff --git a/mistralrs-core/src/dummy_paged_attention/cache_engine.rs b/mistralrs-core/src/dummy_paged_attention/cache_engine.rs index 8dab1636b..f36e4f455 100644 --- a/mistralrs-core/src/dummy_paged_attention/cache_engine.rs +++ b/mistralrs-core/src/dummy_paged_attention/cache_engine.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex, MutexGuard}, }; -use candle_core::{DType, Device, Result, Tensor}; +use mcandle_core::{DType, Device, Result, Tensor}; use super::config::ModelConfigLike; diff --git a/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs b/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs index 7cc29b102..46592cab2 100644 --- a/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs +++ b/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs @@ -1,4 +1,4 @@ -use candle_core::{Device, Result, Tensor}; +use mcandle_core::{Device, Result, Tensor}; use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata; diff --git a/mistralrs-core/src/dummy_paged_attention/mod.rs b/mistralrs-core/src/dummy_paged_attention/mod.rs index 0de51a0f5..fedc65cef 100644 --- a/mistralrs-core/src/dummy_paged_attention/mod.rs +++ b/mistralrs-core/src/dummy_paged_attention/mod.rs @@ -14,9 +14,9 @@ pub const _PAD_SLOT_ID: i64 = -1; pub use block_engine::{BlockEngine, BlockTables, LogicalTokenBlock}; pub use block_engine_sequence::BlockEngineSequence; pub use cache_engine::{CacheConfig, CacheEngine}; -use candle_core::{DType, Device}; pub use config::{ModelConfigLike, ModelConfigMetadata}; pub use layers::PagedAttention; +use mcandle_core::{DType, Device}; pub use scheduler::{ PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput, }; diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 553ab99bd..8fd95986b 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -1,4 +1,4 @@ -use candle_core::Tensor; +use mcandle_core::Tensor; use once_cell::sync::Lazy; use std::{ collections::HashMap, diff --git a/mistralrs-core/src/gguf/content.rs b/mistralrs-core/src/gguf/content.rs index 02bf8286a..9f98a64d5 100644 --- a/mistralrs-core/src/gguf/content.rs +++ b/mistralrs-core/src/gguf/content.rs @@ -1,14 +1,14 @@ use std::{collections::HashMap, fs}; use anyhow::Context; -use candle_core::{ +use indexmap::IndexMap; +use mcandle_core::{ quantized::{ gguf_file::{self, Value}, QTensor, }, Device, Result, }; -use indexmap::IndexMap; use tracing::info; use crate::DEBUG; @@ -68,11 +68,11 @@ impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> { accum }); if n_splits.len() > 1 { - candle_core::bail!("GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?"); + mcandle_core::bail!("GGUF files have differing `split.count` values: {n_splits:?}. 
Perhaps the GGUF files do not match?"); } #[allow(clippy::cast_possible_truncation)] if !n_splits.is_empty() && n_readers != n_splits[0] as usize { - candle_core::bail!( + mcandle_core::bail!( "Number of GGUF files does not match the number of splits, expected {} files.", n_splits[0] ); @@ -120,7 +120,7 @@ impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> { return tensor_info.read(reader, ct.tensor_data_offset, device); } } - candle_core::bail!("Cannot find tensor info for {name}") + mcandle_core::bail!("Cannot find tensor info for {name}") } /// Check for a tensor, searching through each content. diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs index 5b0cd9811..ce7d365ad 100644 --- a/mistralrs-core/src/layers.rs +++ b/mistralrs-core/src/layers.rs @@ -10,11 +10,11 @@ use std::{ }, }; -use candle_core::{ +use mcandle_core::{ quantized::{QMatMul, QTensor}, Context, DType, Device, IndexOp, Result, Tensor, D, }; -use candle_nn::{Conv2d, Conv2dConfig, Linear, Module, VarBuilder}; +use mcandle_nn::{Conv2d, Conv2dConfig, Linear, Module, VarBuilder}; use mistralrs_quant::QuantMethod; use serde::{Deserialize, Serialize}; @@ -37,14 +37,14 @@ pub struct RmsNorm { impl RmsNorm { pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result { - let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?; + let inner = mcandle_nn::rms_norm_non_quant(size, eps, vb)?; let w = inner.inner().weight().clone(); Ok(Self { eps, weight: w }) } /// Gemma uses weight + 1.0 pub fn new_gemma(size: usize, eps: f64, vb: VarBuilder) -> Result { - let inner = candle_nn::rms_norm_non_quant(size, eps, vb)?; + let inner = mcandle_nn::rms_norm_non_quant(size, eps, vb)?; let w = (inner.inner().weight().clone() + 1.0)?; Ok(Self { eps, weight: w }) } @@ -68,7 +68,7 @@ impl RmsNorm { impl Module for RmsNorm { fn forward(&self, x: &Tensor) -> Result { - candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32) + mcandle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32) } } @@ -117,7 +117,7 @@ impl QRmsNorm { } pub fn forward(&self, x: &Tensor) -> Result { - candle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32) + mcandle_nn::ops::rms_norm(&x.contiguous()?, &self.weight, self.eps as f32) } } @@ -142,12 +142,12 @@ pub enum ScaledRopeType { } impl FromStr for ScaledRopeType { - type Err = candle_core::Error; + type Err = mcandle_core::Error; fn from_str(s: &str) -> std::result::Result { match s { "su" | "longrope" => Ok(Self::Su), "yarn" => Ok(Self::Yarn), - _ => Err(candle_core::Error::Msg( + _ => Err(mcandle_core::Error::Msg( "Expected either `su` or `yarn` scaled RoPE type.".to_string(), )), } @@ -290,18 +290,18 @@ impl PhiRotaryEmbedding { let dim = cfg.head_dim; if !matches!(scaling_type, ScaledRopeType::Su) { - candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`."); + mcandle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`."); } if short_factor.len() != dim / 2 { - candle_core::bail!( + mcandle_core::bail!( "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors", short_factor.len(), dim / 2 ); } if long_factor.len() != dim / 2 { - candle_core::bail!( + mcandle_core::bail!( "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors", long_factor.len(), dim / 2 @@ -414,9 +414,9 @@ impl PhiRotaryEmbedding { let cos = cos.narrow(0, *offset, seq_len)?; let sin = sin.narrow(0, *offset, seq_len)?; let q_embed = - 
candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; + mcandle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; let k_embed = - candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; + mcandle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; q_embeds.push(q_embed); k_embeds.push(k_embed); } @@ -600,7 +600,7 @@ impl Llama3RotaryEmbedding { Some(MLlamaRopeScaling { rope_type: other, .. }) => { - candle_core::bail!( + mcandle_core::bail!( "MLlama doesn't support any other RoPE type than `llama3`, got {other:?}" ) } @@ -633,9 +633,9 @@ impl Llama3RotaryEmbedding { let cos = cos.narrow(0, *offset, seq_len)?; let sin = sin.narrow(0, *offset, seq_len)?; let rope = if *is_gptx { - candle_nn::rotary_emb::rope + mcandle_nn::rotary_emb::rope } else { - candle_nn::rotary_emb::rope_i + mcandle_nn::rotary_emb::rope_i }; let q_embed = rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; let k_embed = rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; @@ -723,8 +723,8 @@ impl Qwen2VLRotaryEmbedding { q: &mut Tensor, k: &mut Tensor, ) -> Result<()> { - *q = candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?; - *k = candle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?; + *q = mcandle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin)?; + *k = mcandle_nn::rotary_emb::rope(&k.contiguous()?, cos, sin)?; Ok(()) } } @@ -886,7 +886,7 @@ impl Module for QLinear { } #[derive(Debug, Clone)] -pub struct RotaryEmbedding(candle_nn::RotaryEmbedding); +pub struct RotaryEmbedding(mcandle_nn::RotaryEmbedding); impl RotaryEmbedding { pub fn new( @@ -897,7 +897,7 @@ impl RotaryEmbedding { is_gpt_neox: bool, dtype: DType, ) -> Result { - Ok(Self(candle_nn::RotaryEmbedding::new( + Ok(Self(mcandle_nn::RotaryEmbedding::new( base, head_dim, max_position_embeddings, @@ -916,7 +916,7 @@ impl RotaryEmbedding { is_gpt_neox: bool, dtype: DType, ) -> Result { - Ok(Self(candle_nn::RotaryEmbedding::new_partial( + Ok(Self(mcandle_nn::RotaryEmbedding::new_partial( base, head_dim, rot_dim, @@ -973,15 +973,15 @@ impl Module for Activation { Self::Relu2 => xs.relu()?.sqr(), Self::Relu6 => xs.clamp(0f32, 6f32), Self::Silu => xs.silu(), - Self::Sigmoid => candle_nn::ops::sigmoid(xs), - Self::HardSigmoid => candle_nn::ops::hard_sigmoid(xs), - Self::Swiglu => candle_nn::ops::swiglu(xs), - Self::Swish => xs * candle_nn::ops::sigmoid(xs)?, - Self::HardSwish => xs * candle_nn::ops::hard_sigmoid(xs)?, + Self::Sigmoid => mcandle_nn::ops::sigmoid(xs), + Self::HardSigmoid => mcandle_nn::ops::hard_sigmoid(xs), + Self::Swiglu => mcandle_nn::ops::swiglu(xs), + Self::Swish => xs * mcandle_nn::ops::sigmoid(xs)?, + Self::HardSwish => xs * mcandle_nn::ops::hard_sigmoid(xs)?, &Self::Elu(alpha) => xs.elu(alpha), - &Self::LeakyRelu(negative_slope) => candle_nn::ops::leaky_relu(xs, negative_slope), + &Self::LeakyRelu(negative_slope) => mcandle_nn::ops::leaky_relu(xs, negative_slope), Self::GeluPytorchTanh => xs.gelu(), - Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?), + Self::QuickGelu => xs * mcandle_nn::ops::sigmoid(&(xs * 1.702f64)?), } } } diff --git a/mistralrs-core/src/layers_masker.rs b/mistralrs-core/src/layers_masker.rs index 12f9a182a..f06cbd792 100644 --- a/mistralrs-core/src/layers_masker.rs +++ b/mistralrs-core/src/layers_masker.rs @@ -2,14 +2,14 @@ use std::ops::Add; -use candle_core::{DType, Device, Result, Tensor, WithDType}; +use mcandle_core::{DType, Device, Result, Tensor, WithDType}; use 
crate::pipeline::KvCache; // https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py pub struct CausalMasker; // https://github.com/mokeyish/candle-ext/blob/main/src/triangular.rs fn apply_tril(xs: &Tensor, diagonal: isize) -> Result { let device = xs.device(); let (l, s) = xs.dims2()?; @@ -23,7 +23,7 @@ fn apply_tril(xs: &Tensor, diagonal: isize) -> Result { xs * Tensor::from_vec(xs_tri, (l, s), device)?.to_dtype(xs.dtype())? } // https://github.com/mokeyish/candle-ext/blob/main/src/masked_fill.rs /// xs are on false (0), value is on true (1) pub fn masked_fill(xs: &Tensor, mask: &Tensor, value: D) -> Result { let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?; @@ -82,7 +82,7 @@ impl<'a> PastKvLenCache for Option<&'a [(Tensor, Tensor)]> { match self { None => Ok(0), Some([(k_cache_1, _), ..]) => Ok(k_cache_1.dims()[2]), - _ => candle_core::bail!("Unreachable"), + _ => mcandle_core::bail!("Unreachable"), } } } @@ -122,7 +122,7 @@ impl CausalMasker { pub fn calculate_past_kv_len( &self, cache: &[Option<(Tensor, Tensor)>], - ) -> candle_core::Result { + ) -> mcandle_core::Result { let kv_cache_1 = &cache[0]; if kv_cache_1.is_none() { return Ok(0); diff --git a/mistralrs-core/src/layers_utils.rs b/mistralrs-core/src/layers_utils.rs index a521aae6b..753d9f239 100644 --- a/mistralrs-core/src/layers_utils.rs +++ b/mistralrs-core/src/layers_utils.rs @@ -1,4 +1,4 @@ -use candle_core::{Result, Tensor}; +use mcandle_core::{Result, Tensor}; pub fn repeat_kv(x: Tensor, n_rep: usize) -> Result { if n_rep == 1 { diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index a9f237ba4..5a5e8b87c 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -1,10 +1,10 @@ #![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::Device; use cublaslt::setup_cublas_lt_wrapper; use engine::Engine; pub use engine::{EngineInstruction, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP}; pub use lora::Ordering; +use mcandle_core::Device; pub use pipeline::ModelCategory; pub use pipeline::Pipeline; #[cfg(feature = "pyo3_macros")] @@ -242,30 +242,30 @@ pub(crate) static INHIBIT_GEMM_F16: AtomicBool = AtomicBool::new(false); #[cfg(feature = "cuda")] fn set_gemm_reduced_precision_f16() { - use candle_core::{DType, Device, Tensor}; + use mcandle_core::{DType, Device, Tensor}; // NOTE(EricLBuehler): When we support multi-GPU inference, we should check for each gpu here let a = Tensor::zeros((2, 2), DType::BF16, &Device::new_cuda(0).unwrap()).unwrap(); - candle_core::cuda::set_gemm_reduced_precision_bf16(true); + mcandle_core::cuda::set_gemm_reduced_precision_bf16(true); match a.matmul(&a) { Ok(_) => tracing::info!("Enabling GEMM reduced precision in BF16."), Err(e) => { if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { tracing::info!("GEMM reduced precision in BF16 not supported."); - candle_core::cuda::set_gemm_reduced_precision_bf16(false); + mcandle_core::cuda::set_gemm_reduced_precision_bf16(false); INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); } } } let a = Tensor::zeros((2, 2), DType::F16, &Device::new_cuda(0).unwrap()).unwrap(); - candle_core::cuda::set_gemm_reduced_precision_f16(true); + mcandle_core::cuda::set_gemm_reduced_precision_f16(true); match a.matmul(&a) { Ok(_) => tracing::info!("Enabling GEMM reduced
precision in F16."), Err(e) => { if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { tracing::info!("GEMM reduced precision in F16 not supported."); - candle_core::cuda::set_gemm_reduced_precision_f16(false); + mcandle_core::cuda::set_gemm_reduced_precision_f16(false); INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); } } diff --git a/mistralrs-core/src/lora/loralinear.rs b/mistralrs-core/src/lora/loralinear.rs index 587fc54af..2c1a8bede 100644 --- a/mistralrs-core/src/lora/loralinear.rs +++ b/mistralrs-core/src/lora/loralinear.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc}; -use candle_core::{bail, DType, Module, Result, Tensor}; -use candle_nn::{Linear, VarBuilder}; use either::Either; +use mcandle_core::{bail, DType, Module, Result, Tensor}; +use mcandle_nn::{Linear, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear}; use super::{ diff --git a/mistralrs-core/src/lora/mod.rs b/mistralrs-core/src/lora/mod.rs index f006affae..407c8fcce 100644 --- a/mistralrs-core/src/lora/mod.rs +++ b/mistralrs-core/src/lora/mod.rs @@ -2,9 +2,9 @@ use std::{collections::HashSet, fmt::Debug, sync::Arc}; -use candle_core::{quantized::QTensor, DType, IndexOp, Result, Tensor, D}; -use candle_nn::{init, Linear, Module, VarBuilder}; use loralinear::LoraLinear; +use mcandle_core::{quantized::QTensor, DType, IndexOp, Result, Tensor, D}; +use mcandle_nn::{init, Linear, Module, VarBuilder}; use mistralrs_quant::QuantMethod; pub use qloralinear::QLoraLinear; use serde::Deserialize; @@ -190,7 +190,7 @@ pub fn linear( let module = prefix.split('.').last().unwrap(); let linear_config = LoraLinearConfig::new(d1, d2); - let inner = candle_nn::linear(d1, d2, base_vb.clone())?; + let inner = mcandle_nn::linear(d1, d2, base_vb.clone())?; let target_modules = &lora_config.first().map(|c| &c.1.target_modules); for (_, cfg) in lora_config { @@ -198,7 +198,7 @@ pub fn linear( .as_ref() .is_some_and(|target_modules| &cfg.target_modules != *target_modules) { - candle_core::bail!("Expected all target modules to be the same."); + mcandle_core::bail!("Expected all target modules to be the same."); } } @@ -242,7 +242,7 @@ pub fn linear_no_bias( let module = prefix.split('.').last().unwrap(); let linear_config = LoraLinearConfig::new(d1, d2); - let inner = candle_nn::linear_no_bias(d1, d2, base_vb.clone())?; + let inner = mcandle_nn::linear_no_bias(d1, d2, base_vb.clone())?; let target_modules = &lora_config.first().map(|c| &c.1.target_modules); for (_, cfg) in lora_config { @@ -250,7 +250,7 @@ pub fn linear_no_bias( .as_ref() .is_some_and(|target_modules| &cfg.target_modules != *target_modules) { - candle_core::bail!("Expected all target modules to be the same."); + mcandle_core::bail!("Expected all target modules to be the same."); } } diff --git a/mistralrs-core/src/lora/qloralinear.rs b/mistralrs-core/src/lora/qloralinear.rs index fe8dbb6bb..eda4e2ffa 100644 --- a/mistralrs-core/src/lora/qloralinear.rs +++ b/mistralrs-core/src/lora/qloralinear.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, iter::zip, ops::Mul, sync::Arc}; -use candle_core::{bail, quantized::QMatMul, DType, Module, Result, Tensor}; -use candle_nn::{Linear, VarBuilder}; use either::Either; +use mcandle_core::{bail, quantized::QMatMul, DType, Module, Result, Tensor}; +use mcandle_nn::{Linear, VarBuilder}; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig, UnquantLinear}; use super::{ @@ -41,7 +41,7 @@ impl QLoraLinear { .as_ref() 
.is_some_and(|target_modules| &cfg.target_modules != *target_modules) { - candle_core::bail!("Expected all target modules to be the same."); + mcandle_core::bail!("Expected all target modules to be the same."); } } diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index bbad7693c..9f97c38b8 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, sync::Arc}; -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Linear, RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{Linear, RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use crate::{ @@ -62,9 +62,9 @@ impl Config { match (self.hidden_act, self.hidden_activation) { (None, Some(act)) | (Some(act), None) => Ok(act), (Some(_), Some(_)) => { - candle_core::bail!("both hidden_act and hidden_activation are set") + mcandle_core::bail!("both hidden_act and hidden_activation are set") } - (None, None) => candle_core::bail!("none of hidden_act and hidden_activation are set"), + (None, None) => mcandle_core::bail!("none of hidden_act and hidden_activation are set"), } } } @@ -424,7 +424,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -454,7 +454,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index 25fa6d16e..8f23a77cc 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, sync::Arc}; -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{Linear, RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{Linear, RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use crate::{ @@ -58,7 +58,7 @@ impl Config { // If both are set just use hidden_act Ok(act) } - (None, None) => candle_core::bail!("none of hidden_act and hidden_activation are set"), + (None, None) => mcandle_core::bail!("none of hidden_act and hidden_activation are set"), } } } @@ -461,7 +461,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -493,7 +493,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 5126669f2..66677d1ad 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{embedding, 
Embedding, Module, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, sync::Arc}; @@ -266,7 +266,7 @@ impl MlpLayer for Mlp { if let Some(t) = self.c_fc1.quantized_act_type() { x = x.to_dtype(t)?; } - let x = (candle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)? + let x = (mcandle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)? * MatMul.qmethod_matmul(&x, &*self.c_fc2)?)?; let mut res = MatMul.qmethod_matmul(&x, &*self.c_proj)?; if self.c_fc1.quantized_act_type().is_some() { @@ -468,7 +468,7 @@ impl Llama { )? } else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?, None, ), diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index 4c838c6ce..d751d8f61 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -1,8 +1,8 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] /// Mistral LLM, https://github.com/mistralai/mistral-src -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::VarBuilder; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::VarBuilder; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use serde::Serialize; use std::{collections::HashMap, sync::Arc}; @@ -402,7 +402,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -451,7 +451,7 @@ impl Model { } let mapper = normal_loading_metadata.mapper; - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -525,7 +525,7 @@ impl Model { )? } else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device( embed_tokens.embeddings(), normal_loading_metadata.loading_isq, diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index 1e894fbc3..efe81203c 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -3,8 +3,8 @@ /// Mixtral Model /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py /// https://mistral.ai/news/mixtral-of-experts/ -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, sync::Arc}; @@ -323,7 +323,7 @@ impl Module for SparseMoeBlock { router_logits = router_logits.to_dtype(original_dtype)?; } - let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + let routing_weights = mcandle_nn::ops::softmax_last_dim(&router_logits)?; // In order to extract topk, we extract the data from the tensor and manipulate it // directly. Maybe we will want to use some custom ops instead at some point. 
@@ -455,7 +455,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -485,7 +485,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -557,7 +557,7 @@ impl Model { )? } else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device( embed_tokens.embeddings(), normal_loading_metadata.loading_isq, diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index 25e824bca..2cc186128 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -6,8 +6,8 @@ use std::{collections::HashMap, sync::Arc}; /// https://huggingface.co/microsoft/phi-2 /// This corresponds to the model update made with the following commit: /// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantizedConfig}; use serde::{Deserialize, Serialize}; diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 9384df0fd..0217376c0 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -2,8 +2,8 @@ // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py -use candle_core::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::VarBuilder; +use mcandle_core::{DType, Device, Module, Result, Tensor, D}; +use mcandle_nn::VarBuilder; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use std::{collections::HashMap, sync::Arc}; @@ -396,7 +396,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -426,7 +426,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -490,7 +490,7 @@ impl Model { )? 
} else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device( embed_tokens.embeddings(), normal_loading_metadata.loading_isq, diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs index 672065ec4..2f6d6448b 100644 --- a/mistralrs-core/src/models/phi3_5_moe.rs +++ b/mistralrs-core/src/models/phi3_5_moe.rs @@ -2,8 +2,8 @@ // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py -use candle_core::{Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, VarBuilder}; +use mcandle_core::{Device, IndexOp, Module, Result, Tensor, D}; +use mcandle_nn::{layer_norm, LayerNorm, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantizedConfig}; use std::{collections::HashMap, sync::Arc}; @@ -301,7 +301,7 @@ impl Mlp { } struct MoeMlp { - gate: candle_nn::Linear, + gate: mcandle_nn::Linear, experts: Vec, router_jitter_noise: f64, num_experts: usize, @@ -310,7 +310,7 @@ struct MoeMlp { impl MoeMlp { fn new(cfg: &Config, vb: VarBuilder, layer_device: Device) -> Result { let num_experts = cfg.num_local_experts; - let gate = candle_nn::linear_no_bias( + let gate = mcandle_nn::linear_no_bias( cfg.hidden_size, num_experts, vb.pp("gate").set_device(layer_device), @@ -344,7 +344,7 @@ impl MoeMlp { let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?; // Compute scores - let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?; + let masked_gates = mcandle_nn::ops::softmax_last_dim(&masked_gates)?; let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?; // Mask out first expert @@ -368,7 +368,7 @@ impl MoeMlp { // Apply mask let masked_gates_top2 = masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?; - let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?; + let masked_gates_top2 = mcandle_nn::ops::softmax_last_dim(&masked_gates_top2)?; let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?; let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?; @@ -399,7 +399,7 @@ impl MoeMlp { // One hot encode the selected experts to create an expert mask // this will be used to easily index which expert to activate let experts_mask = - candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)? + mcandle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)? 
.permute((2, 1, 0))?; // Loop over all avail experts in the model and perform the computation on each expert @@ -524,7 +524,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: LayerNorm, lm_head: Arc, @@ -554,7 +554,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index de1374ebe..7acf64aa8 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -3,10 +3,10 @@ use std::collections::HashMap; use std::sync::Arc; -use candle_core::quantized::ggml_file; -use candle_core::quantized::QTensor; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, RotaryEmbedding}; +use mcandle_core::quantized::ggml_file; +use mcandle_core::quantized::QTensor; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{Embedding, Module, RotaryEmbedding}; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; use crate::attention::SdpaParams; @@ -37,7 +37,7 @@ impl Mlp { fn forward(&self, xs: &Tensor) -> Result { let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?; let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?; - let y = &(candle_nn::ops::silu(&w1)? * w3)?; + let y = &(mcandle_nn::ops::silu(&w1)? * w3)?; MatMul.qmethod_matmul(y, &*self.feed_forward_w2) } } @@ -62,7 +62,7 @@ impl MlpOrMoe { let (b_size, seq_len, hidden_dim) = xs.dims3()?; let xs = xs.reshape(((), hidden_dim))?; let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?; - let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + let routing_weights = mcandle_nn::ops::softmax_last_dim(&router_logits)?; // In order to extract topk, we extract the data from the tensor and manipulate it // directly. Maybe we will want to use some custom ops instead at some point. 
@@ -424,7 +424,7 @@ impl ModelConfig::FromGGUF for ModelWeights { rope_freq_base, key_length, value_length, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let qtok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = qtok_embeddings.dequantize(device)?; @@ -440,7 +440,7 @@ impl ModelConfig::FromGGUF for ModelWeights { let head_dim = key_length; if key_length != value_length { - candle_core::bail!( + mcandle_core::bail!( "Expected key_length == value_length, got {key_length} != {value_length}" ); } diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index 3c6719838..8374b19a3 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -2,10 +2,10 @@ use std::sync::Arc; -use candle_core::quantized::QMatMul; -use candle_core::quantized::QTensor; -use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm}; +use mcandle_core::quantized::QMatMul; +use mcandle_core::quantized::QTensor; +use mcandle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use mcandle_nn::{Embedding, LayerNorm}; use mistralrs_quant::GgufMatMul; use mistralrs_quant::QuantMethod; use mistralrs_quant::QuantMethodConfig; @@ -67,8 +67,11 @@ impl LayerWeights { for (b, offset) in (0..xs.dim(0)?).zip(start_offsets) { let cos = self.cos.narrow(0, *offset, seq_len)?; let sin = self.sin.narrow(0, *offset, seq_len)?; - let xs_rot = - candle_nn::rotary_emb::rope(&xs_rot.i(b)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; + let xs_rot = mcandle_nn::rotary_emb::rope( + &xs_rot.i(b)?.unsqueeze(0)?.contiguous()?, + &cos, + &sin, + )?; chunks.push(Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?); } Tensor::cat(&chunks, 0)?.contiguous() @@ -238,7 +241,7 @@ impl ModelConfig::FromGGUF for ModelWeights { rope_dim, ln_eps, max_seq_len, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len)?; diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index db4e910ce..f5eb289c6 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -14,10 +14,10 @@ use crate::utils::gguf_metadata::ContentMetadata; use crate::utils::model_config as ModelConfig; use crate::utils::progress::NiceProgressBar; use crate::{DeviceMapMetadata, Topology}; -use candle_core::quantized::QMatMul; -use candle_core::quantized::QTensor; -use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::Embedding; +use mcandle_core::quantized::QMatMul; +use mcandle_core::quantized::QTensor; +use mcandle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use mcandle_nn::Embedding; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; #[derive(Clone)] @@ -66,7 +66,7 @@ impl LayerWeights { for (i, offset) in seqlen_offsets.iter().enumerate() { let cos = self.cos.narrow(0, *offset, seq_len)?; let sin = self.sin.narrow(0, *offset, seq_len)?; - outputs.push(candle_nn::rotary_emb::rope( + outputs.push(mcandle_nn::rotary_emb::rope( &xs.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin, @@ -249,7 +249,7 @@ impl ModelConfig::FromGGUF for ModelWeights { rope_dim, rms_eps, 
context_window, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?; diff --git a/mistralrs-core/src/models/quantized_qwen2.rs b/mistralrs-core/src/models/quantized_qwen2.rs index 681288dad..da942f859 100644 --- a/mistralrs-core/src/models/quantized_qwen2.rs +++ b/mistralrs-core/src/models/quantized_qwen2.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use std::sync::Arc; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, RotaryEmbedding}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{Embedding, Module, RotaryEmbedding}; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; use crate::attention::SdpaParams; @@ -32,7 +32,7 @@ impl Mlp { fn forward(&self, xs: &Tensor) -> Result { let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?; let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?; - let y = &(candle_nn::ops::silu(&w1)? * w3)?; + let y = &(mcandle_nn::ops::silu(&w1)? * w3)?; MatMul.qmethod_matmul(y, &*self.feed_forward_w2) } } @@ -230,7 +230,7 @@ impl ModelConfig::FromGGUF for ModelWeights { rope_freq_base, key_length, value_length, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let qtok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = qtok_embeddings.dequantize(device)?; @@ -246,7 +246,7 @@ impl ModelConfig::FromGGUF for ModelWeights { let head_dim = key_length; if key_length != value_length { - candle_core::bail!( + mcandle_core::bail!( "Expected key_length == value_length, got {key_length} != {value_length}" ); } diff --git a/mistralrs-core/src/models/quantized_starcoder2.rs b/mistralrs-core/src/models/quantized_starcoder2.rs index 33265aa5a..be9f12752 100644 --- a/mistralrs-core/src/models/quantized_starcoder2.rs +++ b/mistralrs-core/src/models/quantized_starcoder2.rs @@ -15,10 +15,10 @@ use crate::utils::gguf_metadata::ContentMetadata; use crate::utils::model_config as ModelConfig; use crate::utils::progress::NiceProgressBar; use crate::{DeviceMapMetadata, Topology}; -use candle_core::quantized::QMatMul; -use candle_core::quantized::QTensor; -use candle_core::{DType, Device, IndexOp, Module, Result, Tensor}; -use candle_nn::{Embedding, LayerNorm}; +use mcandle_core::quantized::QMatMul; +use mcandle_core::quantized::QTensor; +use mcandle_core::{DType, Device, IndexOp, Module, Result, Tensor}; +use mcandle_nn::{Embedding, LayerNorm}; use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig}; #[derive(Clone)] @@ -32,7 +32,7 @@ impl Module for Mlp { MatMul.qmethod_matmul( &MatMul .qmethod_matmul(xs, &*self.ffn_up)? 
- .apply(&candle_nn::Activation::GeluPytorchTanh)?, + .apply(&mcandle_nn::Activation::GeluPytorchTanh)?, &*self.ffn_down, ) } @@ -216,7 +216,7 @@ impl ModelConfig::FromGGUF for ModelWeights { layer_norm_epsilon, context_window, rope_freq_base, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let tok_embeddings = ct.tensor("token_embd.weight", device)?; let tok_embeddings = tok_embeddings.dequantize(device)?; diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index 101dbdb03..4bef68b7e 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use std::{collections::HashMap, sync::Arc}; @@ -393,7 +393,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -423,7 +423,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -497,7 +497,7 @@ impl Model { )? } else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device( embed_tokens.embeddings(), normal_loading_metadata.loading_isq, diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs index e292437bf..068370cc5 100644 --- a/mistralrs-core/src/models/starcoder2.rs +++ b/mistralrs-core/src/models/starcoder2.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{layer_norm, LayerNorm, Linear, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{layer_norm, LayerNorm, Linear, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use std::{collections::HashMap, sync::Arc}; @@ -387,7 +387,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: LayerNorm, lm_head: Arc, @@ -417,7 +417,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index 295ac9407..2ca7a214a 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -1,4 +1,4 @@ -use candle_core::{ +use mcandle_core::{ backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape, Tensor, WithDType, D, }; @@ -11,10 +11,10 @@ use std::{ #[cfg(feature = "cuda")] use crate::cuda::ffi; #[cfg(feature = "cuda")] -use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr}; 
-#[cfg(feature = "cuda")] use half::{bf16, f16}; #[cfg(feature = "cuda")] +use mcandle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr}; +#[cfg(feature = "cuda")] use std::ffi::c_void; pub enum BitWiseOpEnum { And, @@ -322,9 +322,9 @@ impl BitWiseOp for Tensor { #[cfg(feature = "metal")] fn bitwise_and(&self, rhs: &Tensor) -> Result { let original_device = rhs.device(); - self.to_device(&candle_core::Device::Cpu)? + self.to_device(&mcandle_core::Device::Cpu)? .apply_op2_no_bwd( - &rhs.to_device(&candle_core::Device::Cpu)?, + &rhs.to_device(&mcandle_core::Device::Cpu)?, &BitWise::new(BitWiseOpEnum::And), )? .to_device(original_device) @@ -337,9 +337,9 @@ impl BitWiseOp for Tensor { #[cfg(feature = "metal")] fn bitwise_or(&self, rhs: &Tensor) -> Result { let original_device = rhs.device(); - self.to_device(&candle_core::Device::Cpu)? + self.to_device(&mcandle_core::Device::Cpu)? .apply_op2_no_bwd( - &rhs.to_device(&candle_core::Device::Cpu)?, + &rhs.to_device(&mcandle_core::Device::Cpu)?, &BitWise::new(BitWiseOpEnum::Or), )? .to_device(original_device) @@ -352,9 +352,9 @@ impl BitWiseOp for Tensor { #[cfg(feature = "metal")] fn bitwise_xor(&self, rhs: &Tensor) -> Result { let original_device = rhs.device(); - self.to_device(&candle_core::Device::Cpu)? + self.to_device(&mcandle_core::Device::Cpu)? .apply_op2_no_bwd( - &rhs.to_device(&candle_core::Device::Cpu)?, + &rhs.to_device(&mcandle_core::Device::Cpu)?, &BitWise::new(BitWiseOpEnum::Xor), )? .to_device(original_device) @@ -389,26 +389,26 @@ impl NonZero { } #[cfg(feature = "cuda")] -fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) -> u32 { +fn count_nonzero_cuda(dtype: mcandle_core::DType, d_in: *const c_void, n: u32) -> u32 { unsafe { match dtype { - candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n), - candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n), - candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n), - candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n), - candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n), - candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), - candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n), - candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n), - candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n), - candle_core::DType::F8E4M3 => todo!(), + mcandle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n), + mcandle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n), + mcandle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n), + mcandle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n), + mcandle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n), + mcandle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), + mcandle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n), + mcandle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n), + mcandle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n), + mcandle_core::DType::F8E4M3 => todo!(), } } } #[cfg(feature = "cuda")] fn nonzero_cuda( - dtype: candle_core::DType, + dtype: mcandle_core::DType, d_in: *const c_void, n: u32, num_nonzero: u32, @@ -418,32 +418,32 @@ fn nonzero_cuda( ) { unsafe { match dtype { - candle_core::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out), - candle_core::DType::U32 => { + mcandle_core::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out), + mcandle_core::DType::U32 => { ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::I64 => { + 
mcandle_core::DType::I64 => { ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::I32 => { + mcandle_core::DType::I32 => { ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::I16 => { + mcandle_core::DType::I16 => { ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::BF16 => { + mcandle_core::DType::BF16 => { ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::F16 => { + mcandle_core::DType::F16 => { ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::F32 => { + mcandle_core::DType::F32 => { ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::F64 => { + mcandle_core::DType::F64 => { ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out) } - candle_core::DType::F8E4M3 => todo!(), + mcandle_core::DType::F8E4M3 => todo!(), } } } @@ -458,16 +458,16 @@ impl CustomOp1 for NonZero { return Err(Error::RequiresContiguous { op: "nonzero" }); } let result = match storage { - candle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout), - candle_core::CpuStorage::F8E4M3(_vs) => todo!(), + mcandle_core::CpuStorage::U8(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::U32(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::I16(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::I32(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::I64(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::BF16(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::F16(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::F32(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::F64(vs) => self.nonzero(vs, layout), + mcandle_core::CpuStorage::F8E4M3(_vs) => todo!(), }; let index_len = layout.dims().len(); let result_len = result.len() / index_len; @@ -478,24 +478,24 @@ impl CustomOp1 for NonZero { #[cfg(feature = "cuda")] fn cuda_fwd( &self, - storage: &candle_core::CudaStorage, + storage: &mcandle_core::CudaStorage, layout: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { if !layout.is_contiguous() { - return Err(candle_core::Error::RequiresContiguous { op: "nonzero" }); + return Err(mcandle_core::Error::RequiresContiguous { op: "nonzero" }); } let dev = storage.device().clone(); let d_in = match storage.dtype() { - candle_core::DType::U8 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::U32 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::I32 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::I16 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::I64 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::BF16 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::F16 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::F32 => *storage.as_cuda_slice::()?.device_ptr(), - 
candle_core::DType::F64 => *storage.as_cuda_slice::()?.device_ptr(), - candle_core::DType::F8E4M3 => todo!(), + mcandle_core::DType::U8 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::U32 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::I32 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::I16 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::I64 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::BF16 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::F16 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::F32 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::F64 => *storage.as_cuda_slice::()?.device_ptr(), + mcandle_core::DType::F8E4M3 => todo!(), } as *const c_void; let n = layout.shape().elem_count(); let num_nonzero = count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?); @@ -521,7 +521,7 @@ impl CustomOp1 for NonZero { d_out_ptr, ); let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]); - let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev); + let dst = mcandle_core::CudaStorage::wrap_cuda_slice(d_out, dev); Ok((dst, shape)) } } @@ -534,17 +534,17 @@ impl NonZeroOp for Tensor { #[cfg(feature = "metal")] fn nonzero(&self) -> Result { if !self.is_contiguous() { - return Err(candle_core::Error::RequiresContiguous { op: "nonzero" }); + return Err(mcandle_core::Error::RequiresContiguous { op: "nonzero" }); } let original_device = self.device(); - self.to_device(&candle_core::Device::Cpu)? + self.to_device(&mcandle_core::Device::Cpu)? .apply_op1_no_bwd(&NonZero {})? .to_device(original_device) } #[cfg(not(feature = "metal"))] fn nonzero(&self) -> Result { if !self.is_contiguous() { - return Err(candle_core::Error::RequiresContiguous { op: "nonzero" }); + return Err(mcandle_core::Error::RequiresContiguous { op: "nonzero" }); } self.apply_op1_no_bwd(&NonZero {}) } @@ -596,7 +596,7 @@ impl RepeatInterleaveOp for Tensor { fn repeat_interleave_flat(&self, repeats: Vec) -> Result { let xs = self.flatten_all()?; if repeats.len() != xs.dim(0)? { - candle_core::bail!( + mcandle_core::bail!( "repeats ({}) must match flattened self length ({})", repeats.len(), xs.dim(0)? 
@@ -634,8 +634,8 @@ mod tests { #[test] fn test_topk() { use crate::ops::{TopKLastDimOp, TopKOutput}; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; // [[1, 3, 5], // [2, 4, 6]] let x = Tensor::arange(1f32, 7f32, &device) @@ -664,8 +664,8 @@ mod tests { #[test] fn test_nonzero_cpu() { use crate::ops::NonZeroOp; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; let a = Tensor::from_vec( vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0], &[2, 4], @@ -680,8 +680,8 @@ mod tests { #[test] fn test_nonzero_cuda() { use crate::ops::NonZeroOp; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec( vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0], &[2, 4], @@ -695,8 +695,8 @@ mod tests { #[test] fn test_bitwise_and_cpu() { use crate::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; let a = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = @@ -709,8 +709,8 @@ mod tests { #[test] fn test_bitwise_and_cuda() { use crate::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = @@ -722,8 +722,8 @@ mod tests { #[test] fn test_bitwise_or_cpu() { use crate::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; let a = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap(); @@ -735,8 +735,8 @@ mod tests { #[test] fn test_bitwise_or_cuda() { use crate::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap(); @@ -747,8 +747,8 @@ mod tests { #[test] fn test_bitwise_xor_cpu() { use crate::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; let a = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap(); @@ -760,8 +760,8 @@ mod tests { #[test] fn test_bitwise_xor_cuda() { use crate::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap(); @@ -772,7 +772,7 @@ mod tests { #[test] fn test_nonzero_and() { use crate::ops::{BitWiseOp, NonZeroOp}; - use 
candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let input1 = Tensor::from_vec( vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], @@ -817,7 +817,7 @@ mod tests { #[test] fn nonzero_and_cuda() { use crate::ops::{BitWiseOp, NonZeroOp}; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let device = Device::new_cuda(0).unwrap(); let input1 = @@ -852,9 +852,9 @@ mod tests { } #[test] - fn test_repeat_interleave() -> candle_core::Result<()> { + fn test_repeat_interleave() -> mcandle_core::Result<()> { use crate::ops::RepeatInterleaveOp; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let input = Tensor::new( vec![vec![vec![1f32, 2., 3.], vec![4f32, 5., 6.]]], @@ -874,9 +874,9 @@ mod tests { } #[test] - fn test_repeat_interleave_flat() -> candle_core::Result<()> { + fn test_repeat_interleave_flat() -> mcandle_core::Result<()> { use crate::ops::RepeatInterleaveOp; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let input = Tensor::new(vec![1., 2., 3., 4.], &Device::Cpu)?; diff --git a/mistralrs-core/src/paged_attention/cache_engine.rs b/mistralrs-core/src/paged_attention/cache_engine.rs index 3faeb2481..eef15ac29 100644 --- a/mistralrs-core/src/paged_attention/cache_engine.rs +++ b/mistralrs-core/src/paged_attention/cache_engine.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex, MutexGuard}, }; -use candle_core::{DType, Device, Result, Tensor}; +use mcandle_core::{DType, Device, Result, Tensor}; use mistralrs_paged_attn::{copy_blocks, swap_blocks}; use super::config::ModelConfigLike; diff --git a/mistralrs-core/src/paged_attention/layers/paged_attention.rs b/mistralrs-core/src/paged_attention/layers/paged_attention.rs index 4852aec1a..4db102777 100644 --- a/mistralrs-core/src/paged_attention/layers/paged_attention.rs +++ b/mistralrs-core/src/paged_attention/layers/paged_attention.rs @@ -1,4 +1,4 @@ -use candle_core::{Device, Result, Tensor}; +use mcandle_core::{Device, Result, Tensor}; use mistralrs_paged_attn::{paged_attention, reshape_and_cache}; @@ -100,7 +100,7 @@ impl PagedAttention { }; let att = att.broadcast_add(&mask.unsqueeze(0)?.unsqueeze(0)?)?; - let att = candle_nn::ops::softmax_last_dim(&att)?; + let att = mcandle_nn::ops::softmax_last_dim(&att)?; if key_value_heads != attention_heads { let value_repeat = if key_value_heads == 1 { value.broadcast_as((batch_size, attention_heads, seq_len, head_size))? 
diff --git a/mistralrs-core/src/paged_attention/mod.rs b/mistralrs-core/src/paged_attention/mod.rs index ea0bd26ce..df00fbd0a 100644 --- a/mistralrs-core/src/paged_attention/mod.rs +++ b/mistralrs-core/src/paged_attention/mod.rs @@ -14,9 +14,9 @@ pub const _PAD_SLOT_ID: i64 = -1; pub use block_engine::{BlockEngine, BlockTables, LogicalTokenBlock}; pub use block_engine_sequence::BlockEngineSequence; pub use cache_engine::{CacheConfig, CacheEngine}; -use candle_core::{DType, Device}; pub use config::{ModelConfigLike, ModelConfigMetadata}; pub use layers::PagedAttention; +use mcandle_core::{DType, Device}; pub use scheduler::{ PagedAttentionScheduler, PagedAttentionSchedulerConfig, PagedAttentionSchedulerOutput, }; diff --git a/mistralrs-core/src/pipeline/amoe.rs b/mistralrs-core/src/pipeline/amoe.rs index 987c6804a..42de1b985 100644 --- a/mistralrs-core/src/pipeline/amoe.rs +++ b/mistralrs-core/src/pipeline/amoe.rs @@ -7,11 +7,11 @@ use std::{ }; use base64::{engine::general_purpose, Engine}; -use candle_core::{DType, Device, Tensor}; -use candle_nn::{AdamW, Optimizer, ParamsAdamW}; use either::Either; use image::DynamicImage; use indexmap::IndexMap; +use mcandle_core::{DType, Device, Tensor}; +use mcandle_nn::{AdamW, Optimizer, ParamsAdamW}; use mistralrs_quant::IsqType; use rand::{seq::SliceRandom, thread_rng}; use rand_isaac::Isaac64Rng; @@ -251,7 +251,7 @@ impl Pipeline for AnyMoePipeline { fn forward_inputs( &mut self, inputs: Box, - ) -> Result { + ) -> Result { get_mut_arcmutex!(self.target).forward_inputs(inputs) } @@ -262,7 +262,7 @@ impl Pipeline for AnyMoePipeline { prefix_cacher: &mut PrefixCacheManager, disable_eos_stop: bool, rng: Arc>, - ) -> Result<(), candle_core::Error> { + ) -> Result<(), mcandle_core::Error> { get_mut_arcmutex!(self.target) .sample_causal_gen(seqs, logits, prefix_cacher, disable_eos_stop, rng) .await @@ -284,10 +284,10 @@ impl AnyMoePipelineMixin for AnyMoePipeline { revision: Option, layers: Vec, silent: bool, - ) -> anyhow::Result, candle_core::Error> { + ) -> anyhow::Result, mcandle_core::Error> { let mut target = get_mut_arcmutex!(self.target); if !target.amoe_supported() { - candle_core::bail!("AnyMoE is not supported for this model."); + mcandle_core::bail!("AnyMoE is not supported for this model."); } let device = target.device(); @@ -358,7 +358,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline { }, ) }) - .collect::>>()?; + .collect::>>()?; let mut rng = thread_rng(); let mut samples = inputs.into_inner(); @@ -377,7 +377,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline { 0.0, vec![], ) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new( 1, false, false, 0, @@ -409,7 +409,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline { true, Vec::new(), ) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let images = image_urls.as_ref().map(|urls| { urls.iter() .map(|url| -> anyhow::Result { @@ -437,7 +437,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline { let images = match images { Some(Ok(x)) => Some(x), Some(Err(e)) => { - return anyhow::Result::Err(candle_core::Error::Msg(e.to_string())) + return anyhow::Result::Err(mcandle_core::Error::Msg(e.to_string())) } None => None, }; @@ -497,7 +497,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline { let cached = target.amoe_take_cached_gating_outputs(); for (layer, (optimizer, output)) in optimizers.iter_mut().zip(cached).enumerate() { - let loss = candle_nn::loss::cross_entropy( + let loss = 
mcandle_nn::loss::cross_entropy( &output, &labels.to_device(output.device())?, )?; @@ -518,26 +518,26 @@ impl AnyMoePipelineMixin for AnyMoePipeline { .extension() .is_some_and(|e| e.to_string_lossy() == *"csv") { - candle_core::bail!("`loss_csv_path` must have an extension `csv`."); + mcandle_core::bail!("`loss_csv_path` must have an extension `csv`."); } - let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?; + let mut writer = csv::Writer::from_path(path).map_err(mcandle_core::Error::msg)?; let mut header = vec![format!("Step")]; header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}"))); writer .write_record(&header) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; for (i, row) in all_losses.into_iter().enumerate() { let mut new_row = vec![format!("Step {i}")]; new_row.extend(row.iter().map(|x| format!("{x:.4}"))); writer .write_record(&new_row) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } - writer.flush().map_err(candle_core::Error::msg)?; + writer.flush().map_err(mcandle_core::Error::msg)?; } Ok(Some(AnyMoeTrainingResult { diff --git a/mistralrs-core/src/pipeline/cache_manager.rs b/mistralrs-core/src/pipeline/cache_manager.rs index acd567369..22ad261d8 100644 --- a/mistralrs-core/src/pipeline/cache_manager.rs +++ b/mistralrs-core/src/pipeline/cache_manager.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, Mutex, MutexGuard}; -use candle_core::{Result, Tensor, D}; +use mcandle_core::{Result, Tensor, D}; use crate::{get_mut_arcmutex, sequence::Sequence}; @@ -122,7 +122,7 @@ impl SingleCache { (diff + NormalCache::CACHE_GROW_SIZE - 1) / NormalCache::CACHE_GROW_SIZE; self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE; if self.capacity_seq_len > self.max_seq_len { - candle_core::bail!( + mcandle_core::bail!( "kv-cache: requested capacity ({}) above max seq len ({})", self.capacity_seq_len, self.max_seq_len @@ -501,8 +501,8 @@ impl Cache { None => (k, v), Some((k_cache, v_cache)) => { if !slow_cat { - let k = candle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?; - let v = candle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?; + let k = mcandle_nn::ops::kvconcat(k_cache, &k, 2)?.contiguous()?; + let v = mcandle_nn::ops::kvconcat(v_cache, &v, 2)?.contiguous()?; (k, v) } else { let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; @@ -556,8 +556,8 @@ impl Cache { } } let (k, v) = if !slow_cat { - let k = candle_nn::ops::kvconcat(&prev_k, &k, 2)?; - let v = candle_nn::ops::kvconcat(&prev_v, &v, 2)?; + let k = mcandle_nn::ops::kvconcat(&prev_k, &k, 2)?; + let v = mcandle_nn::ops::kvconcat(&prev_v, &v, 2)?; (k, v) } else { let k = Tensor::cat(&[prev_k, k], 2)?.contiguous()?; diff --git a/mistralrs-core/src/pipeline/diffusion.rs b/mistralrs-core/src/pipeline/diffusion.rs index 02a825e42..fd8576d25 100644 --- a/mistralrs-core/src/pipeline/diffusion.rs +++ b/mistralrs-core/src/pipeline/diffusion.rs @@ -14,9 +14,9 @@ use crate::utils::debug::DeviceRepr; use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors}; use crate::{DeviceMapMetadata, PagedAttentionConfig, Pipeline, TryIntoDType}; use anyhow::Result; -use candle_core::{DType, Device, Tensor}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use image::{DynamicImage, RgbImage}; +use mcandle_core::{DType, Device, Tensor}; use mistralrs_quant::IsqType; use rand_isaac::Isaac64Rng; use std::any::Any; @@ -192,7 +192,7 @@ impl Loader for DiffusionLoader { |_| true, ) }) - .collect::>>()?; + 
.collect::>>()?; self.inner.load( configs, @@ -300,7 +300,10 @@ impl MetadataMixin for DiffusionPipeline { #[async_trait::async_trait] impl Pipeline for DiffusionPipeline { - fn forward_inputs(&mut self, inputs: Box) -> candle_core::Result { + fn forward_inputs( + &mut self, + inputs: Box, + ) -> mcandle_core::Result { let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed."); let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?; let (_b, c, h, w) = img.dims4()?; @@ -308,12 +311,12 @@ impl Pipeline for DiffusionPipeline { for b_img in img.chunk(img.dim(0)?, 0)? { let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?; if c != 3 { - candle_core::bail!("Expected 3 channels in image output"); + mcandle_core::bail!("Expected 3 channels in image output"); } #[allow(clippy::cast_possible_truncation)] images.push(DynamicImage::ImageRgb8( RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::()?).ok_or( - candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()), + mcandle_core::Error::Msg("RgbImage has invalid capacity.".to_string()), )?, )); } @@ -326,8 +329,8 @@ impl Pipeline for DiffusionPipeline { _prefix_cacher: &mut PrefixCacheManager, _disable_eos_stop: bool, _srng: Arc>, - ) -> Result<(), candle_core::Error> { - candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`"); + ) -> Result<(), mcandle_core::Error> { + mcandle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`"); } fn category(&self) -> ModelCategory { ModelCategory::Diffusion diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 128bfe956..976414371 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -30,9 +30,9 @@ use crate::{ xlora_models::XLoraQLlama, }; use anyhow::Result; -use candle_core::quantized::ggml_file; -use candle_core::{DType, Device, Tensor}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; +use mcandle_core::quantized::ggml_file; +use mcandle_core::{DType, Device, Tensor}; use mistralrs_quant::IsqType; use rand_isaac::Isaac64Rng; use std::any::Any; @@ -526,7 +526,7 @@ impl Pipeline for GGMLPipeline { fn forward_inputs( &mut self, inputs: Box, - ) -> Result { + ) -> Result { let ModelInputs { input_ids, input_ids_full, @@ -571,7 +571,7 @@ impl Pipeline for GGMLPipeline { prefix_cacher: &mut PrefixCacheManager, disable_eos_stop: bool, rng: Arc>, - ) -> Result<(), candle_core::Error> { + ) -> Result<(), mcandle_core::Error> { sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await } fn category(&self) -> ModelCategory { diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 8edbedd3c..4f95a1a15 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -42,9 +42,9 @@ use crate::{ xlora_models::{XLoraQLlama, XLoraQPhi3}, }; use anyhow::{bail, Result}; -use candle_core::{DType, Device, Tensor}; use either::Either; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; +use mcandle_core::{DType, Device, Tensor}; use mistralrs_quant::IsqType; use rand_isaac::Isaac64Rng; use std::any::Any; @@ -670,7 +670,7 @@ impl Pipeline for GGUFPipeline { fn forward_inputs( &mut self, inputs: Box, - ) -> Result { + ) -> Result { let ModelInputs { input_ids, input_ids_full, @@ -691,11 +691,11 @@ impl Pipeline for GGUFPipeline { (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)), (Some(_), None) => { // This can 
happen if Rust-side user code is wrong - candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") + mcandle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") } (None, Some(_)) => { // This should never happen but we handle it anyway - candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") + mcandle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") } (None, None) => None, }; @@ -762,7 +762,7 @@ impl Pipeline for GGUFPipeline { prefix_cacher: &mut PrefixCacheManager, disable_eos_stop: bool, rng: Arc>, - ) -> Result<(), candle_core::Error> { + ) -> Result<(), mcandle_core::Error> { sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await } fn category(&self) -> ModelCategory { diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs index 73662cb8d..dac7cf2a2 100644 --- a/mistralrs-core/src/pipeline/inputs_processor.rs +++ b/mistralrs-core/src/pipeline/inputs_processor.rs @@ -3,7 +3,7 @@ use std::{any::Any, num::NonZeroUsize, sync::Arc}; use anyhow::Result; -use candle_core::Device; +use mcandle_core::Device; use text_models_inputs_processor::PagedAttentionMeta; use tokenizers::Tokenizer; @@ -50,7 +50,7 @@ pub mod text_models_inputs_processor { use std::{any::Any, fmt::Debug, iter::repeat, num::NonZeroUsize, sync::Arc}; use anyhow::Result; - use candle_core::{DType, Device, Tensor, WithDType}; + use mcandle_core::{DType, Device, Tensor, WithDType}; use tokenizers::Tokenizer; use crate::{ diff --git a/mistralrs-core/src/pipeline/isq.rs b/mistralrs-core/src/pipeline/isq.rs index 3aa96dbc9..3915f860e 100644 --- a/mistralrs-core/src/pipeline/isq.rs +++ b/mistralrs-core/src/pipeline/isq.rs @@ -9,8 +9,8 @@ use std::{ }; use anyhow::Result; -use candle_core::{Context, Device, Tensor}; use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle}; +use mcandle_core::{Context, Device, Tensor}; use mistralrs_quant::{ FP8Linear, GgufMatMul, HqqLayer, IsqType, QuantMethod, QuantizedSerde, QuantizedSerdeType, UnquantLinear, @@ -168,7 +168,7 @@ pub trait IsqModel { organization: IsqOrganization, write_artifacts: Option<&PathBuf>, full_ser: UqffFullSer<'_>, - ) -> candle_core::Result<()> { + ) -> mcandle_core::Result<()> { { let (mut tensors, mapper) = match organization { IsqOrganization::Default => self.get_layers(), @@ -269,7 +269,7 @@ pub trait IsqModel { let pool = rayon::ThreadPoolBuilder::new() .num_threads(minimum_max_threads) .build() - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; pool.install(|| { use indicatif::ParallelProgressIterator; @@ -308,7 +308,7 @@ pub trait IsqModel { ); if !serialized.extension().is_some_and(|ext| ext == "uqff") { - candle_core::bail!("UQFF output path extension must be `.uqff`",); + mcandle_core::bail!("UQFF output path extension must be `.uqff`",); } let bar = ProgressBar::new(total_tensors as u64); @@ -322,7 +322,7 @@ pub trait IsqModel { let pool = rayon::ThreadPoolBuilder::new() .num_threads(2) .build() - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let quantized_values = pool.install(|| { if silent { @@ -336,7 +336,7 @@ pub trait 
IsqModel { Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?, )) }) - .collect::>>() + .collect::>>() } else { tensors .par_iter() @@ -349,7 +349,7 @@ pub trait IsqModel { Tensor::new(Cow::into_owned(layer.serialize()?), &Device::Cpu)?, )) }) - .collect::>>() + .collect::>>() } }); @@ -400,7 +400,7 @@ pub trait IsqModel { info!("Serializing tokenizer to `{}`.", tokenizer_out.display()); serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; if let Some(template_filename) = template_filename { info!( @@ -409,9 +409,9 @@ pub trait IsqModel { ); let template = - std::fs::read(template_filename).map_err(candle_core::Error::msg)?; + std::fs::read(template_filename).map_err(mcandle_core::Error::msg)?; std::fs::write(&tokenizer_cfg_out, template) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } if let Some(generation_config) = generation_config { @@ -420,8 +420,8 @@ pub trait IsqModel { gen_cfg_out.display() ); - let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?; - std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?; + let cfg = std::fs::read(generation_config).map_err(mcandle_core::Error::msg)?; + std::fs::write(&gen_cfg_out, cfg).map_err(mcandle_core::Error::msg)?; } if let Some(processor_config) = processor_filename { @@ -430,8 +430,8 @@ pub trait IsqModel { processor_out.display() ); - let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?; - std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?; + let cfg = std::fs::read(processor_config).map_err(mcandle_core::Error::msg)?; + std::fs::write(&processor_out, cfg).map_err(mcandle_core::Error::msg)?; } if let Some(preprocessor_config) = preprocessor_filename { @@ -441,8 +441,8 @@ pub trait IsqModel { ); let cfg = - std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?; - std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?; + std::fs::read(preprocessor_config).map_err(mcandle_core::Error::msg)?; + std::fs::write(&preprocessor_out, cfg).map_err(mcandle_core::Error::msg)?; } } let delta = Instant::now().duration_since(t_start).as_secs_f32(); @@ -457,7 +457,7 @@ pub trait IsqModel { topology: Option<&Topology>, silent: bool, artifacts: &PathBuf, - ) -> candle_core::Result<()> { + ) -> mcandle_core::Result<()> { let (tensors, mapper) = self.get_layers(); let total_tensors = tensors.len(); @@ -491,7 +491,7 @@ pub trait IsqModel { devices.push(device); } - let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::new(artifacts)? }; + let artifacts = unsafe { mcandle_core::safetensors::MmapedSafetensors::new(artifacts)? 
}; let artifact_isqs = artifacts .tensors() @@ -506,7 +506,7 @@ pub trait IsqModel { .collect::>(); if artifact_isqs.len() != total_tensors { - candle_core::bail!( + mcandle_core::bail!( "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})", artifact_isqs.len(), ); @@ -549,7 +549,7 @@ pub trait IsqModel { } Ok(()) }) - .collect::>>()?; + .collect::>>()?; } else { (0..tensors.len()) .into_par_iter() @@ -578,7 +578,7 @@ pub trait IsqModel { } Ok(()) }) - .collect::>>()?; + .collect::>>()?; } let delta = Instant::now().duration_since(t_start).as_secs_f32(); diff --git a/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs b/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs index eaf4262ff..e0dc08ffd 100644 --- a/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/diffusion_loaders.rs @@ -6,8 +6,8 @@ use std::{ }; use anyhow::{Context, Result}; -use candle_core::{Device, Tensor}; -use candle_nn::VarBuilder; +use mcandle_core::{Device, Tensor}; +use mcandle_nn::VarBuilder; use hf_hub::api::sync::ApiRepo; #[cfg(feature = "pyo3_macros")] @@ -40,7 +40,7 @@ pub trait DiffusionModel { &mut self, prompts: Vec, params: DiffusionGenerationParams, - ) -> candle_core::Result; + ) -> mcandle_core::Result; fn device(&self) -> &Device; fn max_seq_len(&self) -> usize; } diff --git a/mistralrs-core/src/pipeline/loaders/mod.rs b/mistralrs-core/src/pipeline/loaders/mod.rs index 92c3864bd..53e21b68a 100644 --- a/mistralrs-core/src/pipeline/loaders/mod.rs +++ b/mistralrs-core/src/pipeline/loaders/mod.rs @@ -12,7 +12,7 @@ use std::{ use anyhow::Result; use as_any::AsAny; -use candle_core::Device; +use mcandle_core::Device; use mistralrs_quant::IsqType; use tokio::sync::Mutex; @@ -354,7 +354,7 @@ impl ModelKind { /// # Example /// ```no_run /// use mistralrs_core::{Loader, TokenSource, DeviceMapMetadata, ModelDType}; -/// use candle_core::Device; +/// use mcandle_core::Device; /// /// let loader: Box = todo!(); /// let pipeline = loader.load_model_from_hf( diff --git a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs index 39e7d62f7..5401432b3 100644 --- a/mistralrs-core/src/pipeline/loaders/normal_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs @@ -20,8 +20,8 @@ use crate::{ xlora_models::NonGranularState, }; use anyhow::Result; -use candle_core::{Device, Tensor}; -use candle_nn::VarBuilder; +use mcandle_core::{Device, Tensor}; +use mcandle_nn::VarBuilder; use mistralrs_quant::QuantizedConfig; #[cfg(feature = "pyo3_macros")] @@ -46,7 +46,7 @@ pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin { position_ids: Vec, metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, flash_params: &FlashParams, - ) -> candle_core::Result; + ) -> mcandle_core::Result; #[allow(clippy::too_many_arguments)] fn xlora_forward( &self, @@ -62,14 +62,14 @@ pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin { position_ids: Vec, flash_params: &FlashParams, flash_params_full: &FlashParams, - ) -> candle_core::Result; + ) -> mcandle_core::Result; fn is_xlora(&self) -> bool; fn device(&self) -> &Device; fn cache(&self) -> &EitherCache; fn max_seq_len(&self) -> usize; - fn activate_adapters(&mut self, _: Vec) -> candle_core::Result { + fn activate_adapters(&mut self, _: Vec) -> mcandle_core::Result { // NOTE: While X-LoRA shares a similar name, it is not equivalent. Its adapter set must remain the same. 
- candle_core::bail!( + mcandle_core::bail!( "Activating adapters is only supported for models fine-tuned with LoRA." ); } diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs index 607e7aeed..7ca915a73 100644 --- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use std::{fmt::Debug, str::FromStr}; use anyhow::Result; -use candle_core::{Device, Tensor}; -use candle_nn::VarBuilder; +use mcandle_core::{Device, Tensor}; +use mcandle_nn::VarBuilder; #[cfg(feature = "pyo3_macros")] use pyo3::pyclass; @@ -46,7 +46,7 @@ pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin { model_specific_args: Box, // pixel attention mask, or image sizes, or anything else metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, flash_params: &FlashParams, - ) -> candle_core::Result; + ) -> mcandle_core::Result; fn device(&self) -> &Device; fn cache(&self) -> &EitherCache; fn max_seq_len(&self) -> usize; diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 81267de31..2567b0957 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -54,7 +54,7 @@ use tokenizers::Tokenizer; pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig}; use anyhow::Result; -use candle_core::{DType, Device, IndexOp, Tensor, Var}; +use mcandle_core::{DType, Device, IndexOp, Tensor, Var}; use crate::sequence::Sequence; @@ -155,7 +155,7 @@ pub trait AnyMoePipelineMixin { fn amoe_layer_vars(&self) -> Vec> { unreachable!() } - fn amoe_finish_training(&mut self, _gate_model_id: Option) -> candle_core::Result<()> { + fn amoe_finish_training(&mut self, _gate_model_id: Option) -> mcandle_core::Result<()> { unreachable!() } fn amoe_base_model_trainable_params(&self) -> usize { @@ -184,7 +184,7 @@ pub trait AnyMoePipelineMixin { _expert_type: AnyMoeExpertType, _silent: bool, _gate_model_id: Option, - ) -> candle_core::Result<()> { + ) -> mcandle_core::Result<()> { unreachable!() } /// Pre-train the gating layers @@ -198,7 +198,7 @@ pub trait AnyMoePipelineMixin { _revision: Option, _layers: Vec, _silent: bool, - ) -> Result, candle_core::Error> { + ) -> Result, mcandle_core::Error> { unreachable!() } } @@ -253,7 +253,7 @@ pub enum ForwardInputsResult { } impl ForwardInputsResult { - fn index_bs(&self, bs_idx: usize) -> candle_core::Result { + fn index_bs(&self, bs_idx: usize) -> mcandle_core::Result { match self { Self::CausalGeneration { logits } => Ok(Self::CausalGeneration { logits: logits.i(bs_idx)?, @@ -264,7 +264,7 @@ impl ForwardInputsResult { } } - fn to_device(&self, device: &Device) -> candle_core::Result { + fn to_device(&self, device: &Device) -> mcandle_core::Result { match self { Self::CausalGeneration { logits } => Ok(Self::CausalGeneration { logits: logits.to_device(device)?, @@ -288,7 +288,7 @@ pub trait Pipeline: fn forward_inputs( &mut self, inputs: Box, - ) -> Result; + ) -> Result; /// Returns the total of model execution time. 
#[allow(clippy::too_many_arguments)] @@ -300,7 +300,7 @@ pub trait Pipeline: disable_eos_stop: bool, rng: Arc>, backend_metadata: CacheBackendMetadata<'_>, - ) -> Result { + ) -> Result { match backend_metadata { CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => { let inputs_iter = self.get_processor().inputs_processor().process_inputs( @@ -323,14 +323,14 @@ pub trait Pipeline: let InputProcessorOutput { inputs, seq_indices, - } = inputs.map_err(candle_core::Error::msg)?; + } = inputs.map_err(mcandle_core::Error::msg)?; if i == 0 { match pre_op { CacheInstruction::In(ref adapter_inst) => { match adapter_inst { AdapterInstruction::Activate(adapters) => { self.activate_adapters(adapters.clone()).map_err(|e| { - candle_core::Error::msg(>::as_ref( &e @@ -345,7 +345,7 @@ pub trait Pipeline: match adapter_inst { AdapterInstruction::Activate(adapters) => { self.activate_adapters(adapters.clone()).map_err(|e| { - candle_core::Error::msg(>::as_ref( &e @@ -363,7 +363,7 @@ pub trait Pipeline: match adapter_inst { AdapterInstruction::Activate(adapters) => { self.activate_adapters(adapters.clone()).map_err(|e| { - candle_core::Error::msg(>::as_ref( &e @@ -399,7 +399,7 @@ pub trait Pipeline: l.expect("Did not get any inputs. This is shocking.") .to_device(&Device::Cpu) }) - .collect::>>()?; + .collect::>>()?; match post_op { CacheInstruction::Out => self.clone_out_cache(input_seqs, false), @@ -501,7 +501,7 @@ pub trait Pipeline: let InputProcessorOutput { inputs, seq_indices, - } = inputs.map_err(candle_core::Error::msg)?; + } = inputs.map_err(mcandle_core::Error::msg)?; let start = Instant::now(); let raw_logits = self.forward_inputs(inputs)?; @@ -519,7 +519,7 @@ pub trait Pipeline: l.expect("Did not get any inputs. This is shocking.") .to_device(&Device::Cpu) }) - .collect::>>()?; + .collect::>>()?; let start = Instant::now(); match &logits[0] { @@ -581,7 +581,7 @@ pub trait Pipeline: prefix_cacher: &mut PrefixCacheManager, disable_eos_stop: bool, rng: Arc>, - ) -> Result<(), candle_core::Error>; + ) -> Result<(), mcandle_core::Error>; fn category(&self) -> ModelCategory; } @@ -589,7 +589,7 @@ pub trait Pipeline: pub(crate) fn extract_logits( logits: &Tensor, context_lens: Vec<(usize, usize)>, -) -> candle_core::Result { +) -> mcandle_core::Result { let mut toks = Vec::new(); for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) { toks.push(dim.narrow(1, start, len)?); diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 551795fd2..4d1679b21 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -35,8 +35,8 @@ use crate::{ Topology, TryIntoDType, }; use anyhow::Result; -use candle_core::{Device, Tensor, Var}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; +use mcandle_core::{Device, Tensor, Var}; use mistralrs_quant::IsqType; use rand_isaac::Isaac64Rng; use regex_automata::meta::Regex; @@ -591,7 +591,7 @@ impl Pipeline for NormalPipeline { fn forward_inputs( &mut self, inputs: Box, - ) -> Result { + ) -> Result { let ModelInputs { input_ids, input_ids_full, @@ -612,11 +612,11 @@ impl Pipeline for NormalPipeline { (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)), (Some(_), None) => { // This can happen if Rust-side user code is wrong - candle_core::bail!("Forward step expected a PagedAttention input metadata. 
This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") + mcandle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") } (None, Some(_)) => { // This should never happen but we handle it anyway - candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") + mcandle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") } (None, None) => None, }; @@ -654,7 +654,7 @@ impl Pipeline for NormalPipeline { prefix_cacher: &mut PrefixCacheManager, disable_eos_stop: bool, rng: Arc>, - ) -> Result<(), candle_core::Error> { + ) -> Result<(), mcandle_core::Error> { sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await } fn category(&self) -> ModelCategory { @@ -663,7 +663,7 @@ impl Pipeline for NormalPipeline { } impl AnyMoePipelineMixin for NormalPipeline { - fn amoe_finish_training(&mut self, gate_model_id: Option) -> candle_core::Result<()> { + fn amoe_finish_training(&mut self, gate_model_id: Option) -> mcandle_core::Result<()> { self.model.finish_training(gate_model_id) } fn amoe_layer_vars(&self) -> Vec> { @@ -682,26 +682,26 @@ impl AnyMoePipelineMixin for NormalPipeline { revision: Option, match_regex: &str, config: crate::amoe::AnyMoeConfig, - dtype: candle_core::DType, + dtype: mcandle_core::DType, dev: &Device, (prefix, mlp): (String, String), layers: Vec, expert_type: AnyMoeExpertType, silent: bool, gate_model_id: Option, - ) -> candle_core::Result<()> { + ) -> mcandle_core::Result<()> { let mut vbs = Vec::new(); // Precompile regex here - let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?; + let regex = Regex::new(match_regex).map_err(mcandle_core::Error::msg)?; for model_id in model_ids { let model_id_str = &model_id; let model_id = Path::new(&model_id); let api = ApiBuilder::new() .with_progress(!silent) - .with_token(get_token(token).map_err(candle_core::Error::msg)?) + .with_token(get_token(token).map_err(mcandle_core::Error::msg)?) .build() - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let revision = revision.clone().unwrap_or("main".to_string()); let api = api.repo(Repo::with_revision( model_id_str.clone(), @@ -748,9 +748,9 @@ impl AnyMoePipelineMixin for NormalPipeline { let api = ApiBuilder::new() .with_progress(!silent) - .with_token(get_token(token).map_err(candle_core::Error::msg)?) + .with_token(get_token(token).map_err(mcandle_core::Error::msg)?) 
.build() - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let revision = revision.clone().unwrap_or("main".to_string()); let api = api.repo(Repo::with_revision( model_id_str.clone(), diff --git a/mistralrs-core/src/pipeline/sampling.rs b/mistralrs-core/src/pipeline/sampling.rs index cab5c6f95..5bd348483 100644 --- a/mistralrs-core/src/pipeline/sampling.rs +++ b/mistralrs-core/src/pipeline/sampling.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use candle_core::{DType, Device, Result, Tensor}; +use mcandle_core::{DType, Device, Result, Tensor}; use rand_isaac::Isaac64Rng; use crate::{ @@ -26,7 +26,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( this.get_metadata() .tok_trie .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie" .to_string(), ))? @@ -125,7 +125,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( token: crate::handle_seq_error_ok!( tokenizer .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer" .to_string(), ))?.decode(&[logprob.token], false), @@ -160,7 +160,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( txt[..completion_bytes_pos].trim_start().to_string() } crate::sequence::StopReason::GeneratedImage => { - candle_core::bail!("Stop reason was `GeneratedImage`.") + mcandle_core::bail!("Stop reason was `GeneratedImage`.") } }; @@ -168,7 +168,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( let mut tool_calls = Vec::new(); let mut text_new = Some(text.clone()); if let Some(ref matcher) = seq.tools { - let calls = matcher.get_call(&text).map_err(candle_core::Error::msg)?; + let calls = matcher.get_call(&text).map_err(mcandle_core::Error::msg)?; if !calls.is_empty() { text_new = None; } @@ -216,7 +216,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( seq.responder(), ) .await - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } else { group .maybe_send_completion_done_response( @@ -232,7 +232,7 @@ pub(crate) async fn finish_or_add_toks_to_seq( seq.responder(), ) .await - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } } this.reset_non_granular_state(); @@ -339,7 +339,7 @@ pub async fn sample_sequence( -f32::INFINITY; seq.tok_trie .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "TokTrie must be present in pipeline if bias is calculated".to_string() ))? 
.vocab_size() @@ -381,14 +381,14 @@ pub async fn sample_sequence( .as_ref() .unwrap() .append_token(rx.as_mut(), second_logprobs_response.token) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } SequenceRecognizer::Cfg(ref mut cfg) => { seq.tok_trie .as_ref() .unwrap() .append_token(cfg.as_mut(), second_logprobs_response.token) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } SequenceRecognizer::None => {} } diff --git a/mistralrs-core/src/pipeline/speculative.rs b/mistralrs-core/src/pipeline/speculative.rs index 6a6d5516c..d827917be 100644 --- a/mistralrs-core/src/pipeline/speculative.rs +++ b/mistralrs-core/src/pipeline/speculative.rs @@ -6,7 +6,7 @@ use std::{ }; use anyhow::Result as anyhowResult; -use candle_core::{Device, IndexOp, Result, Tensor}; +use mcandle_core::{Device, IndexOp, Result, Tensor}; use mistralrs_quant::IsqType; use rand_isaac::Isaac64Rng; use tokenizers::Tokenizer; @@ -184,7 +184,7 @@ impl SpeculativePipeline { if get_mut_arcmutex!(target) .tokenizer() .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "`SpeculativePipeline::new` requires the target pipeline to have a token trie" .to_string(), ))? @@ -192,16 +192,16 @@ impl SpeculativePipeline { != get_mut_arcmutex!(draft) .tokenizer() .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "`SpeculativePipeline::new` requires the draft pipeline to have a token trie" .to_string(), ))? .get_vocab(true) { - candle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding."); + mcandle_core::bail!("Target and draft models' tokenizer vocab do not match. This is required for speculative decoding."); } if get_mut_arcmutex!(target).category() != get_mut_arcmutex!(draft).category() { - candle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding."); + mcandle_core::bail!("Target and draft models' category do not match. This is required for speculative decoding."); } if get_mut_arcmutex!(target) .get_processor() @@ -212,7 +212,7 @@ impl SpeculativePipeline { .inputs_processor() .get_type() { - candle_core::bail!("Target and draft models' input processors do not match. This is required for speculative decoding."); + mcandle_core::bail!("Target and draft models' input processors do not match. 
This is required for speculative decoding."); } let metadata = get_mut_arcmutex!(target).get_metadata().clone(); let category = get_mut_arcmutex!(target).category(); @@ -346,7 +346,7 @@ impl Pipeline for SpeculativePipeline { match adapter_inst { AdapterInstruction::Activate(adapters) => { self.activate_adapters(adapters).map_err(|e| { - candle_core::Error::msg(>::as_ref( &e @@ -361,7 +361,7 @@ impl Pipeline for SpeculativePipeline { match adapter_inst { AdapterInstruction::Activate(adapters) => { self.activate_adapters(adapters).map_err(|e| { - candle_core::Error::msg(>::as_ref( &e @@ -379,7 +379,7 @@ impl Pipeline for SpeculativePipeline { match adapter_inst { AdapterInstruction::Activate(adapters) => { self.activate_adapters(adapters).map_err(|e| { - candle_core::Error::msg(>::as_ref( &e @@ -433,7 +433,7 @@ impl Pipeline for SpeculativePipeline { #[allow(irrefutable_let_patterns)] let ForwardInputsResult::CausalGeneration { logits } = logits else { - candle_core::bail!( + mcandle_core::bail!( "Speculative decoding requires `CausalGeneration` forward results" ); }; @@ -506,7 +506,7 @@ impl Pipeline for SpeculativePipeline { #[allow(irrefutable_let_patterns)] let ForwardInputsResult::CausalGeneration { logits } = logits else { - candle_core::bail!( + mcandle_core::bail!( "Speculative decoding requires `CausalGeneration` forward results" ); }; @@ -616,22 +616,22 @@ impl Pipeline for SpeculativePipeline { .get_metadata() .tok_trie .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "`SpeculativePipeline::step` requires a token trie".to_string(), ))? .append_token(rx.as_mut(), accepted.token) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } SequenceRecognizer::Cfg(ref mut cfg) => { get_mut_arcmutex!(self.target) .get_metadata() .tok_trie .as_ref() - .ok_or(candle_core::Error::Msg( + .ok_or(mcandle_core::Error::Msg( "`SpeculativePipeline::step` requires a token trie".to_string(), ))? .append_token(cfg.as_mut(), accepted.token) - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; } SequenceRecognizer::None => {} } diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index b9f8863e6..b6a93f6c7 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -28,8 +28,8 @@ use crate::{ TryIntoDType, }; use anyhow::Result; -use candle_core::{Device, Tensor, Var}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; +use mcandle_core::{Device, Tensor, Var}; use mistralrs_quant::IsqType; use rand_isaac::Isaac64Rng; use regex_automata::meta::Regex; @@ -482,7 +482,10 @@ impl MetadataMixin for VisionPipeline { #[async_trait::async_trait] impl Pipeline for VisionPipeline { - fn forward_inputs(&mut self, inputs: Box) -> candle_core::Result { + fn forward_inputs( + &mut self, + inputs: Box, + ) -> mcandle_core::Result { let ModelInputs { input_ids, seqlen_offsets, @@ -501,11 +504,11 @@ impl Pipeline for VisionPipeline { (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)), (Some(_), None) => { // This can happen if Rust-side user code is wrong - candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") + mcandle_core::bail!("Forward step expected a PagedAttention input metadata. 
This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") } (None, Some(_)) => { // This should never happen but we handle it anyway - candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") + mcandle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") } (None, None) => None, }; @@ -529,7 +532,7 @@ impl Pipeline for VisionPipeline { prefix_cacher: &mut PrefixCacheManager, disable_eos_stop: bool, rng: Arc>, - ) -> Result<(), candle_core::Error> { + ) -> Result<(), mcandle_core::Error> { sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await } fn category(&self) -> ModelCategory { @@ -542,7 +545,7 @@ impl Pipeline for VisionPipeline { } impl AnyMoePipelineMixin for VisionPipeline { - fn amoe_finish_training(&mut self, gate_model_id: Option) -> candle_core::Result<()> { + fn amoe_finish_training(&mut self, gate_model_id: Option) -> mcandle_core::Result<()> { self.model.finish_training(gate_model_id) } fn amoe_layer_vars(&self) -> Vec> { @@ -561,26 +564,26 @@ impl AnyMoePipelineMixin for VisionPipeline { revision: Option, match_regex: &str, config: crate::amoe::AnyMoeConfig, - dtype: candle_core::DType, + dtype: mcandle_core::DType, dev: &Device, (prefix, mlp): (String, String), layers: Vec, expert_type: AnyMoeExpertType, silent: bool, gate_model_id: Option, - ) -> candle_core::Result<()> { + ) -> mcandle_core::Result<()> { let mut vbs = Vec::new(); // Precompile regex here - let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?; + let regex = Regex::new(match_regex).map_err(mcandle_core::Error::msg)?; for model_id in model_ids { let model_id_str = &model_id; let model_id = Path::new(&model_id); let api = ApiBuilder::new() .with_progress(!silent) - .with_token(get_token(token).map_err(candle_core::Error::msg)?) + .with_token(get_token(token).map_err(mcandle_core::Error::msg)?) .build() - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let revision = revision.clone().unwrap_or("main".to_string()); let api = api.repo(Repo::with_revision( model_id_str.clone(), @@ -627,9 +630,9 @@ impl AnyMoePipelineMixin for VisionPipeline { let api = ApiBuilder::new() .with_progress(!silent) - .with_token(get_token(token).map_err(candle_core::Error::msg)?) + .with_token(get_token(token).map_err(mcandle_core::Error::msg)?) 
.build() - .map_err(candle_core::Error::msg)?; + .map_err(mcandle_core::Error::msg)?; let revision = revision.clone().unwrap_or("main".to_string()); let api = api.repo(Repo::with_revision( model_id_str.clone(), diff --git a/mistralrs-core/src/prefix_cacher.rs b/mistralrs-core/src/prefix_cacher.rs index 24bac68e6..7653f0a5c 100644 --- a/mistralrs-core/src/prefix_cacher.rs +++ b/mistralrs-core/src/prefix_cacher.rs @@ -1,6 +1,6 @@ use std::sync::{Arc, Mutex}; -use candle_core::{Device, Result, Tensor}; +use mcandle_core::{Device, Result, Tensor}; use radix_trie::{Trie, TrieCommon, TrieKey}; use crate::{get_mut_arcmutex, pipeline::LayerCaches, sequence::Sequence}; diff --git a/mistralrs-core/src/sampler.rs b/mistralrs-core/src/sampler.rs index 9a205fff7..14503cb6e 100644 --- a/mistralrs-core/src/sampler.rs +++ b/mistralrs-core/src/sampler.rs @@ -6,7 +6,7 @@ use std::{ sync::{Arc, Mutex}, }; -use candle_core::{Device, Error, Result, Tensor, D}; +use mcandle_core::{Device, Error, Result, Tensor, D}; #[cfg(feature = "pyo3_macros")] use pyo3::pyclass; @@ -153,7 +153,7 @@ impl DrySamplingParamsInner { /// ```rust /// use std::{sync::Arc, ops::Mul}; /// use mistralrs_core::CustomLogitsProcessor; -/// use candle_core::{Result, Tensor}; +/// use mcandle_core::{Result, Tensor}; /// /// struct ThresholdLogitsProcessor; /// impl CustomLogitsProcessor for ThresholdLogitsProcessor { @@ -526,7 +526,7 @@ impl Sampler { fn apply_penalties(&self, mut logits: Vec, context: &[u32]) -> Result { if context.is_empty() { - candle_core::bail!("Penalty context is empty, this should not happen."); + mcandle_core::bail!("Penalty context is empty, this should not happen."); } // Dry penalty @@ -661,7 +661,7 @@ impl Sampler { )?, Some(temperature) => { let logits = (&logits / temperature)?; - let probs = candle_nn::ops::softmax_last_dim(&logits)?; + let probs = mcandle_nn::ops::softmax_last_dim(&logits)?; self.sample_speculative_top_kp_min_p( probs, @@ -677,7 +677,7 @@ impl Sampler { None => self.sample_argmax(logits, return_logprobs)?, Some(temperature) => { let logits = (&logits / temperature)?; - let probs = candle_nn::ops::softmax_last_dim(&logits)?; + let probs = mcandle_nn::ops::softmax_last_dim(&logits)?; let mut probs: Vec = probs.to_vec1()?; self.sample_top_kp_min_p( @@ -715,7 +715,7 @@ mod tests { #[test] fn test_argmax() { use super::Sampler; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; use rand::SeedableRng; use rand_isaac::Isaac64Rng; use std::sync::Arc; @@ -747,7 +747,7 @@ mod tests { #[test] fn test_gumbel_speculative() { use super::Sampler; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; use rand::SeedableRng; use rand_isaac::Isaac64Rng; use std::sync::Arc; diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 321f69b75..a0a9f4bde 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -24,7 +24,7 @@ use crate::{ sampler::{Logprobs, Sampler}, ChatCompletionResponse, Usage, }; -use candle_core::Tensor; +use mcandle_core::Tensor; use regex_automata::util::primitives::StateID; #[derive(Clone, Copy, PartialEq, Debug)] diff --git a/mistralrs-core/src/topology/mod.rs b/mistralrs-core/src/topology/mod.rs index 7215af9be..e1cc57c9c 100644 --- a/mistralrs-core/src/topology/mod.rs +++ b/mistralrs-core/src/topology/mod.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, fs, io::Read, ops::Range, path::Path}; -use candle_core::Device; use itertools::Itertools; +use mcandle_core::Device; use 
mistralrs_quant::IsqType; use regex::Regex; use serde::Deserialize; diff --git a/mistralrs-core/src/utils/debug.rs b/mistralrs-core/src/utils/debug.rs index 48a528169..e891ab9e4 100644 --- a/mistralrs-core/src/utils/debug.rs +++ b/mistralrs-core/src/utils/debug.rs @@ -1,4 +1,4 @@ -use candle_core::{Device, DeviceLocation}; +use mcandle_core::{Device, DeviceLocation}; use tracing::level_filters::LevelFilter; use tracing_subscriber::EnvFilter; diff --git a/mistralrs-core/src/utils/gguf_metadata.rs b/mistralrs-core/src/utils/gguf_metadata.rs index 9754990da..2b2a02d3c 100644 --- a/mistralrs-core/src/utils/gguf_metadata.rs +++ b/mistralrs-core/src/utils/gguf_metadata.rs @@ -1,7 +1,7 @@ use akin::akin; use anyhow::ensure; use anyhow::Result; -use candle_core::quantized::gguf_file; +use mcandle_core::quantized::gguf_file; use std::collections::HashMap; use tracing::warn; @@ -75,17 +75,17 @@ impl ContentMetadata<'_> { } } -// These traits below are a workaround for converting candles GGUF `Value` enum type wrapper. +// These traits below are a workaround for converting mcandles GGUF `Value` enum type wrapper. // A better upstream approach would instead be to provide serialize/deserialize support? pub trait TryFromValue { - fn try_from_value(value: gguf_file::Value) -> Result + fn try_from_value(value: gguf_file::Value) -> Result where Self: Sized; } // Value wrapped types, each has a different conversion method: // NOTE: Type conversion methods internally bail with "not a " -// https://docs.rs/candle-core/latest/candle_core/quantized/gguf_file/enum.Value.html#variants +// https://docs.rs/mcandle-core/latest/mcandle_core/quantized/gguf_file/enum.Value.html#variants akin! { let &types = [String, bool, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64]; let &to_type = [ @@ -104,18 +104,18 @@ akin! { ]; impl TryFromValue for *types { - fn try_from_value(value: gguf_file::Value) -> Result { - *to_type.or_else(|_| candle_core::bail!("value is not a `*types`")) + fn try_from_value(value: gguf_file::Value) -> Result { + *to_type.or_else(|_| mcandle_core::bail!("value is not a `*types`")) } } } // Vec to Vec from above types: impl TryFromValue for Vec { - fn try_from_value(value_vec: gguf_file::Value) -> Result { + fn try_from_value(value_vec: gguf_file::Value) -> Result { value_vec .to_vec() - .or_else(|_| candle_core::bail!("value is not a `Vec`"))? + .or_else(|_| mcandle_core::bail!("value is not a `Vec`"))? 
.clone() .into_iter() .map(|item| T::try_from_value(item)) @@ -124,20 +124,20 @@ impl TryFromValue for Vec { } pub trait TryValueInto: Sized { - fn try_value_into(self) -> Result; + fn try_value_into(self) -> Result; } impl TryValueInto for gguf_file::Value { - fn try_value_into(self) -> Result { + fn try_value_into(self) -> Result { T::try_from_value(self) } } impl TryValueInto for Option { - fn try_value_into(self) -> Result { + fn try_value_into(self) -> Result { match self { Some(value) => value.try_value_into(), - None => candle_core::bail!("Expected `Option` to contain a value"), + None => mcandle_core::bail!("Expected `Option` to contain a value"), } } } diff --git a/mistralrs-core/src/utils/memory_usage.rs b/mistralrs-core/src/utils/memory_usage.rs index dcf52a479..247857297 100644 --- a/mistralrs-core/src/utils/memory_usage.rs +++ b/mistralrs-core/src/utils/memory_usage.rs @@ -1,4 +1,4 @@ -use candle_core::{Device, Result}; +use mcandle_core::{Device, Result}; use sysinfo::System; const KB_TO_BYTES: usize = 1024; @@ -16,17 +16,17 @@ impl MemoryUsage { } #[cfg(feature = "cuda")] Device::Cuda(_) => { - use candle_core::cuda_backend::WrapErr; - Ok(candle_core::cuda::cudarc::driver::result::mem_get_info() + use mcandle_core::cuda_backend::WrapErr; + Ok(mcandle_core::cuda::cudarc::driver::result::mem_get_info() .w()? .0) } #[cfg(not(feature = "cuda"))] Device::Cuda(_) => { - candle_core::bail!("Cannot get memory available for CUDA device") + mcandle_core::bail!("Cannot get memory available for CUDA device") } Device::Metal(_) => { - candle_core::bail!("Cannot get memory available for Metal device") + mcandle_core::bail!("Cannot get memory available for Metal device") } } } @@ -41,17 +41,17 @@ impl MemoryUsage { } #[cfg(feature = "cuda")] Device::Cuda(_) => { - use candle_core::cuda_backend::WrapErr; - Ok(candle_core::cuda::cudarc::driver::result::mem_get_info() + use mcandle_core::cuda_backend::WrapErr; + Ok(mcandle_core::cuda::cudarc::driver::result::mem_get_info() .w()? 
.1) } #[cfg(not(feature = "cuda"))] Device::Cuda(_) => { - candle_core::bail!("Cannot get total memory for CUDA device") + mcandle_core::bail!("Cannot get total memory for CUDA device") } Device::Metal(_) => { - candle_core::bail!("Cannot get total memory for Metal device") + mcandle_core::bail!("Cannot get total memory for Metal device") } } } diff --git a/mistralrs-core/src/utils/model_config.rs b/mistralrs-core/src/utils/model_config.rs index 6b912da05..339af2b97 100644 --- a/mistralrs-core/src/utils/model_config.rs +++ b/mistralrs-core/src/utils/model_config.rs @@ -1,7 +1,7 @@ use super::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters}; use anyhow::Result; -use candle_core::quantized::ggml_file; -use candle_nn::VarBuilder; +use mcandle_core::quantized::ggml_file; +use mcandle_nn::VarBuilder; use std::{collections::HashMap, path::PathBuf}; use crate::{ @@ -21,7 +21,7 @@ pub struct FileGGML { #[derive(derive_more::From)] pub struct Device<'a> { - device: &'a candle_core::Device, + device: &'a mcandle_core::Device, pub mapper: DeviceMapMetadata, pub topology: Option<&'a Topology>, } @@ -42,7 +42,7 @@ impl<'a> Adapter<'a> { #[allow(clippy::borrowed_box)] pub fn try_new<'b: 'a>( paths: &'b Box, - device: &'b candle_core::Device, + device: &'b mcandle_core::Device, silent: bool, is_xlora: bool, ) -> Result { @@ -50,7 +50,7 @@ impl<'a> Adapter<'a> { let ordering = paths.get_ordering().as_ref().unwrap(); let preload_adapters = load_preload_adapters( paths.get_lora_preload_adapter_info(), - candle_core::DType::F32, + mcandle_core::DType::F32, device, silent, )?; @@ -74,7 +74,7 @@ impl<'a> Adapter<'a> { .iter() .map(|(_, x)| (*x).to_owned()) .collect::>(), - Some(candle_core::DType::F32), + Some(mcandle_core::DType::F32), device, silent, None, @@ -151,7 +151,7 @@ impl<'a, Q: QuantParams> ModelParams<'a, Q> { // Traits for the existing methods used across various model types to impl `from_ggml()` / `from_gguf()` // Basic: pub trait FromGGML { - fn from_ggml(ct: ggml_file::Content, gqa: usize) -> Result + fn from_ggml(ct: ggml_file::Content, gqa: usize) -> Result where Self: Sized; } @@ -159,11 +159,11 @@ pub trait FromGGML { pub trait FromGGUF { fn from_gguf( ct: Content<'_, R>, - device: &candle_core::Device, + device: &mcandle_core::Device, mapper: DeviceMapMetadata, topology: Option<&Topology>, attention_mechanism: AttentionImplementation, - ) -> Result + ) -> Result where Self: Sized; } @@ -178,7 +178,7 @@ pub trait FromAdapterGGML { ordering: &Ordering, xlora_config: Option, preload_adapters: &Option>, - ) -> Result + ) -> Result where Self: Sized; } @@ -186,7 +186,7 @@ pub trait FromAdapterGGUF { #[allow(clippy::too_many_arguments)] fn from_gguf( ct: Content<'_, R>, - device: &candle_core::Device, + device: &mcandle_core::Device, lora_config: &[((String, String), LoraConfig)], vb: &VarBuilder, ordering: &Ordering, @@ -194,14 +194,14 @@ pub trait FromAdapterGGUF { mapper: DeviceMapMetadata, topology: Option<&Topology>, preload_adapters: &Option>, - ) -> Result + ) -> Result where Self: Sized; } // NOTE: Below is a workaround to proxy params to the existing API methods `get_gguf()` / `get_gmml()` traits covered above. 
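The NOTE above refers to the `TryFrom` impls that follow: a `ModelParams` bundle is converted directly into a concrete quantized model, routing to the existing `from_ggml`/`from_gguf` constructors. A sketch of a call site under the renamed error type (the wrapper function is hypothetical, and the `ModelParams`, `ParamsGGML`, and `QLlama` types are internal to mistralrs-core):

    // Hypothetical helper: `params` is assumed to be assembled elsewhere from a
    // GGML file plus device/topology information.
    fn load_qllama(params: ModelParams<'_, ParamsGGML>) -> mcandle_core::Result<QLlama> {
        // Goes through `impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama`,
        // whose associated Error is now mcandle_core::Error.
        QLlama::try_from(params)
    }
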
impl Config { - pub fn try_into_model(self) -> Result { + pub fn try_into_model(self) -> Result { // Destructure props: let ParamsGGML(FileGGML { ct, gqa }) = self.quant; @@ -211,7 +211,7 @@ impl Config { } impl Config> { - pub fn try_into_model(self) -> Result { + pub fn try_into_model(self) -> Result { // Destructure props: let ParamsGGML(FileGGML { ct, gqa }) = self.quant; @@ -237,7 +237,7 @@ impl Config> { } impl Config, NoAdapter> { - pub fn try_into_model(self) -> Result { + pub fn try_into_model(self) -> Result { // Destructure props: let ParamsGGUF( ct, @@ -255,7 +255,7 @@ impl Config, NoAdapter> { } impl Config, Adapter<'_>> { - pub fn try_into_model(self) -> Result { + pub fn try_into_model(self) -> Result { // Destructure props: let ParamsGGUF( ct, @@ -301,7 +301,7 @@ use crate::{ use akin::akin; impl TryFrom> for QLlama { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result { let config = params.expect_quantized("`Config` should be GGML Quantized"); @@ -310,7 +310,7 @@ impl TryFrom> for QLlama { } impl TryFrom> for XLoraQLlama { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result { let config = params.expect_adapted("`Config` should be GGML Quantized with an Adapter"); @@ -322,7 +322,7 @@ akin! { let &models_gguf = [QLlama, QPhi, QPhi3, QStarcoder2, QQwen2]; impl TryFrom>> for *models_gguf { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result { let config = params.expect_quantized("`Config` should be GGUF Quantized"); @@ -335,7 +335,7 @@ akin! { let &models_gguf_a = [XLoraQLlama, XLoraQPhi3]; impl TryFrom>> for *models_gguf_a { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result { let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter"); diff --git a/mistralrs-core/src/utils/normal.rs b/mistralrs-core/src/utils/normal.rs index bdb092f23..5d62be32f 100644 --- a/mistralrs-core/src/utils/normal.rs +++ b/mistralrs-core/src/utils/normal.rs @@ -1,7 +1,7 @@ use std::{fmt::Display, str::FromStr}; use anyhow::Result; -use candle_core::{DType, Device, Tensor}; +use mcandle_core::{DType, Device, Tensor}; use serde::Deserialize; use tracing::info; @@ -115,7 +115,7 @@ fn get_dtypes() -> Vec { get_dtypes_non_cuda() } -fn determine_auto_dtype_all(devices: &[&Device]) -> candle_core::Result { +fn determine_auto_dtype_all(devices: &[&Device]) -> mcandle_core::Result { let dev_dtypes = get_dtypes(); for dtype in get_dtypes_non_cuda() .iter() @@ -135,14 +135,14 @@ fn determine_auto_dtype_all(devices: &[&Device]) -> candle_core::Result { Ok(_) => (), Err(e) => match e { // For CUDA - candle_core::Error::UnsupportedDTypeForOp(_, _) => continue, + mcandle_core::Error::UnsupportedDTypeForOp(_, _) => continue, // Accelerate backend doesn't support f16/bf16 // Metal backend doesn't support f16 - candle_core::Error::Msg(_) => continue, + mcandle_core::Error::Msg(_) => continue, // This is when the metal backend doesn't support bf16 - candle_core::Error::Metal(_) => continue, + mcandle_core::Error::Metal(_) => continue, // If running with RUST_BACKTRACE=1 - candle_core::Error::WithBacktrace { .. } => continue, + mcandle_core::Error::WithBacktrace { .. 
} => continue, other => return Err(other), }, } diff --git a/mistralrs-core/src/utils/unvarbuilder.rs b/mistralrs-core/src/utils/unvarbuilder.rs index 5614c5e73..9976db060 100644 --- a/mistralrs-core/src/utils/unvarbuilder.rs +++ b/mistralrs-core/src/utils/unvarbuilder.rs @@ -3,9 +3,9 @@ use std::{ sync::{Arc, RwLock}, }; -use candle_core::{quantized::QMatMul, Tensor}; -use candle_nn::{Conv2d, Embedding, LayerNorm, Linear}; use itertools::Itertools; +use mcandle_core::{quantized::QMatMul, Tensor}; +use mcandle_nn::{Conv2d, Embedding, LayerNorm, Linear}; use mistralrs_quant::QuantMethod; use crate::layers::{F32RmsNorm, QLinear, RmsNorm}; diff --git a/mistralrs-core/src/utils/varbuilder_utils.rs b/mistralrs-core/src/utils/varbuilder_utils.rs index 67511d203..c156dcc8e 100644 --- a/mistralrs-core/src/utils/varbuilder_utils.rs +++ b/mistralrs-core/src/utils/varbuilder_utils.rs @@ -7,10 +7,10 @@ use std::{ thread::{self, JoinHandle}, }; -use candle_core::{ +use mcandle_core::{ pickle::PthTensors, safetensors::MmapedSafetensors, DType, Device, Result, Tensor, }; -use candle_nn::{ +use mcandle_nn::{ var_builder::{SimpleBackend, VarBuilderArgs}, VarBuilder, }; @@ -59,7 +59,7 @@ impl TensorLoaderBackend for PickleBackend { let t = self .0 .get(name)? - .ok_or(candle_core::Error::Msg(format!( + .ok_or(mcandle_core::Error::Msg(format!( "Could not load tensor {name}" )))? .to_device(device)?; @@ -197,12 +197,12 @@ trait LoadTensors { .expect("Expected to convert") { "safetensors" => Box::new(SafetensorBackend(unsafe { - candle_core::safetensors::MmapedSafetensors::new(path)? + mcandle_core::safetensors::MmapedSafetensors::new(path)? })), "pth" | "pt" | "bin" => Box::new(PickleBackend( - candle_core::pickle::PthTensors::new(path, None)? + mcandle_core::pickle::PthTensors::new(path, None)? 
)), - other => candle_core::bail!("Unexpected extension `{other}`, this should have been handles by `get_model_paths`."), + other => mcandle_core::bail!("Unexpected extension `{other}`, this should have been handles by `get_model_paths`."), }; // Extracts the tensor name and processes it, filtering tensors and deriving the key name: diff --git a/mistralrs-core/src/vision_models/clip.rs b/mistralrs-core/src/vision_models/clip.rs index 515f4d7ae..947556616 100644 --- a/mistralrs-core/src/vision_models/clip.rs +++ b/mistralrs-core/src/vision_models/clip.rs @@ -2,9 +2,9 @@ use std::sync::Arc; -// Sourced from https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/clip/vision_model.rs -use candle_core::{IndexOp, Result, Shape, Tensor, D}; -use candle_nn::{Conv2dConfig, Module}; +// Sourced from https://github.com/huggingface/mcandle/blob/main/mcandle-transformers/src/models/clip/vision_model.rs +use mcandle_core::{IndexOp, Result, Shape, Tensor, D}; +use mcandle_nn::{Conv2dConfig, Module}; use mistralrs_quant::QuantMethod; use crate::{serde_default_fn, utils::unvarbuilder::UnVarBuilder}; @@ -17,7 +17,7 @@ pub enum Activation { impl Module for Activation { fn forward(&self, xs: &Tensor) -> Result { match self { - Activation::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?), + Activation::QuickGelu => xs * mcandle_nn::ops::sigmoid(&(xs * 1.702f64)?), } } } @@ -54,14 +54,14 @@ pub struct ClipConfig { // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112 #[derive(Clone, Debug)] struct ClipVisionEmbeddings { - patch_embedding: candle_nn::Conv2d, + patch_embedding: mcandle_nn::Conv2d, position_ids: Tensor, class_embedding: Tensor, - position_embedding: candle_nn::Embedding, + position_embedding: mcandle_nn::Embedding, } impl ClipVisionEmbeddings { - fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipConfig) -> Result { // originally nn.Parameter let class_embedding = if vs.contains_tensor("class_embedding") { vs.get(c.hidden_size, "class_embedding")? 
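The `Activation::QuickGelu` arm earlier in this file's diff is the usual quick-GELU approximation, x * sigmoid(1.702 * x), now expressed against `mcandle_nn`. A standalone sketch of the same computation (the toy input and `main` are illustrative only):

    use mcandle_core::{Device, Result, Tensor};

    // Quick-GELU, matching the `Activation::QuickGelu` arm above:
    // x * sigmoid(1.702 * x).
    fn quick_gelu(xs: &Tensor) -> Result<Tensor> {
        xs * mcandle_nn::ops::sigmoid(&(xs * 1.702f64)?)
    }

    fn main() -> Result<()> {
        let x = Tensor::new(&[-1.0f32, 0.0, 1.0], &Device::Cpu)?;
        println!("{}", quick_gelu(&x)?);
        Ok(())
    }
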
@@ -78,8 +78,8 @@ impl ClipVisionEmbeddings { ..Default::default() }; let position_embedding = - candle_nn::embedding(num_positions, c.hidden_size, vs.pp("position_embedding"))?; - let patch_embedding = candle_nn::conv2d_no_bias( + mcandle_nn::embedding(num_positions, c.hidden_size, vs.pp("position_embedding"))?; + let patch_embedding = mcandle_nn::conv2d_no_bias( c.num_channels, c.hidden_size, c.patch_size, @@ -123,7 +123,7 @@ struct ClipAttention { } impl ClipAttention { - fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipConfig) -> Result { let hidden_size = c.hidden_size; let num_attention_heads = c.num_attention_heads; let k_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("k_proj"))?; @@ -177,7 +177,7 @@ impl ClipAttention { attn_weights }; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_weights = mcandle_nn::ops::softmax_last_dim(&attn_weights)?; let attn_output = attn_weights.matmul(&value_states)?; let attn_output = attn_output @@ -196,7 +196,7 @@ struct ClipMlp { } impl ClipMlp { - fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipConfig) -> Result { let fc1 = mistralrs_quant::linear(c.hidden_size, c.intermediate_size, &None, vs.pp("fc1"))?; let fc2 = mistralrs_quant::linear(c.intermediate_size, c.hidden_size, &None, vs.pp("fc2"))?; @@ -218,17 +218,17 @@ impl ClipMlp { #[derive(Clone, Debug)] struct ClipEncoderLayer { self_attn: ClipAttention, - layer_norm1: candle_nn::LayerNorm, + layer_norm1: mcandle_nn::LayerNorm, mlp: ClipMlp, - layer_norm2: candle_nn::LayerNorm, + layer_norm2: mcandle_nn::LayerNorm, } impl ClipEncoderLayer { - fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + fn new(vs: mcandle_nn::VarBuilder, c: &ClipConfig) -> Result { let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; - let layer_norm1 = candle_nn::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm1"))?; + let layer_norm1 = mcandle_nn::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm1"))?; let mlp = ClipMlp::new(vs.pp("mlp"), c)?; - let layer_norm2 = candle_nn::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm2"))?; + let layer_norm2 = mcandle_nn::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm2"))?; Ok(ClipEncoderLayer { self_attn, @@ -257,7 +257,7 @@ pub struct ClipEncoder { } impl ClipEncoder { - pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + pub fn new(vs: mcandle_nn::VarBuilder, c: &ClipConfig) -> Result { let vs = vs.pp("layers"); let mut layers: Vec = Vec::new(); for index in 0..c.num_hidden_layers { @@ -287,18 +287,19 @@ impl ClipEncoder { pub struct ClipVisionTransformer { embeddings: ClipVisionEmbeddings, encoder: ClipEncoder, - pre_layer_norm: candle_nn::LayerNorm, - final_layer_norm: candle_nn::LayerNorm, + pre_layer_norm: mcandle_nn::LayerNorm, + final_layer_norm: mcandle_nn::LayerNorm, } impl ClipVisionTransformer { /// Create a CLIP vision transformer model. Expects the vb to point to the root (not model) /// where (for example) `.pp("embeddings")` is valid. 
- pub fn new(vb: candle_nn::VarBuilder, c: &ClipConfig) -> Result { + pub fn new(vb: mcandle_nn::VarBuilder, c: &ClipConfig) -> Result { let embeddings = ClipVisionEmbeddings::new(vb.pp("embeddings"), c)?; - let pre_layer_norm = candle_nn::layer_norm(c.hidden_size, 1e-5, vb.pp("pre_layrnorm"))?; + let pre_layer_norm = mcandle_nn::layer_norm(c.hidden_size, 1e-5, vb.pp("pre_layrnorm"))?; let encoder = ClipEncoder::new(vb.pp("encoder"), c)?; - let final_layer_norm = candle_nn::layer_norm(c.hidden_size, 1e-5, vb.pp("post_layernorm"))?; + let final_layer_norm = + mcandle_nn::layer_norm(c.hidden_size, 1e-5, vb.pp("post_layernorm"))?; Ok(Self { embeddings, encoder, diff --git a/mistralrs-core/src/vision_models/idefics2.rs b/mistralrs-core/src/vision_models/idefics2.rs index 2d1be2e18..7a3e6601d 100644 --- a/mistralrs-core/src/vision_models/idefics2.rs +++ b/mistralrs-core/src/vision_models/idefics2.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{ +use mcandle_core::{DType, Device, IndexOp, Result, Tensor, D}; +use mcandle_nn::{ conv2d, embedding, layer_norm, linear, linear_no_bias, Conv2d, Conv2dConfig, Embedding, LayerNorm, Module, VarBuilder, }; @@ -416,7 +416,7 @@ impl Attention { attn_weights, &self.neg_inf, )?; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_weights = mcandle_nn::ops::softmax_last_dim(&attn_weights)?; let mut attn_output = attn_weights.matmul(&v.contiguous()?)?; if self.q_proj.is_quant() { @@ -773,7 +773,7 @@ impl PerceiverAttention { attn_weights, &self.neg_inf, )?; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_weights = mcandle_nn::ops::softmax_last_dim(&attn_weights)?; let mut attn_output = attn_weights.matmul(&v.contiguous()?)?; if self.q_proj.is_quant() { @@ -1157,7 +1157,7 @@ impl Idefics2 { &image_hidden_states, )? } else { - candle_core::bail!("Pixel values were specified for a non-prompt.") + mcandle_core::bail!("Pixel values were specified for a non-prompt.") } } else { self.text_model.get_input_embeddings(input_ids)? 
@@ -1249,7 +1249,7 @@ impl VisionModel for Idefics2 { model_specific_args: Box, metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, flash_params: &FlashParams, - ) -> candle_core::Result { + ) -> mcandle_core::Result { let pixel_attention_mask: Option = *model_specific_args .downcast() .expect("Cannot downcast into `Option`"); diff --git a/mistralrs-core/src/vision_models/idefics2_input_processor.rs b/mistralrs-core/src/vision_models/idefics2_input_processor.rs index 51b6aa22f..d4794390f 100644 --- a/mistralrs-core/src/vision_models/idefics2_input_processor.rs +++ b/mistralrs-core/src/vision_models/idefics2_input_processor.rs @@ -2,9 +2,9 @@ use std::{any::Any, num::NonZeroUsize, sync::Arc}; -use candle_core::{Device, Result, Tensor}; use image::{DynamicImage, GenericImageView}; use indexmap::IndexMap; +use mcandle_core::{Device, Result, Tensor}; use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms}; use tokenizers::Tokenizer; use tracing::warn; @@ -312,7 +312,7 @@ impl ImagePreProcessor for Idefics2ImageProcessor { } else if size.contains_key("height") && size.contains_key("width") { (size["height"] as usize, size["width"] as usize) } else { - candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`."); + mcandle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`."); }; *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?); diff --git a/mistralrs-core/src/vision_models/image_processor.rs b/mistralrs-core/src/vision_models/image_processor.rs index f867b91a8..d82d7ab58 100644 --- a/mistralrs-core/src/vision_models/image_processor.rs +++ b/mistralrs-core/src/vision_models/image_processor.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{Device, Result, Tensor}; use image::DynamicImage; +use mcandle_core::{Device, Result, Tensor}; use crate::pipeline::InputsProcessor; diff --git a/mistralrs-core/src/vision_models/llava/llava15.rs b/mistralrs-core/src/vision_models/llava/llava15.rs index e81550b91..d929aac64 100644 --- a/mistralrs-core/src/vision_models/llava/llava15.rs +++ b/mistralrs-core/src/vision_models/llava/llava15.rs @@ -19,8 +19,8 @@ use crate::vision_models::clip::{ClipConfig, ClipVisionTransformer}; use crate::vision_models::llava::config::Config; use crate::AnyMoeConfig; use crate::AnyMoeExpertType; -use candle_core::{bail, DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{linear, Activation, Linear, VarBuilder}; +use mcandle_core::{bail, DType, Device, IndexOp, Result, Tensor}; +use mcandle_nn::{linear, Activation, Linear, VarBuilder}; pub(crate) struct LLaVAVisionSpecificArgs; // only a dumb struct to satisfy the trait @@ -302,7 +302,7 @@ impl VisionModel for Model { _model_specific_args: Box, // pixel attention mask, or image sizes, or anything else metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, flash_params: &FlashParams, - ) -> candle_core::Result { + ) -> mcandle_core::Result { self.forward_inputs( input_ids, pixel_values, diff --git a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs index e4d8b6be4..6269c110b 100644 --- a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs @@ -3,11 +3,11 @@ use std::any::Any; use std::num::NonZeroUsize; use 
std::sync::Arc; -use candle_core::Result; -use candle_core::{DType, Device, Tensor}; use image::GenericImageView; use image::Rgb; use itertools::Itertools; +use mcandle_core::Result; +use mcandle_core::{DType, Device, Tensor}; use regex_automata::meta::Regex; use tokenizers::Tokenizer; use tracing::warn; @@ -334,11 +334,11 @@ impl ImagePreProcessor for LLaVAInputProcessor { images: Vec, videos: Vec>, config: &preprocessor_config::PreProcessorConfig, - device: &candle_core::Device, + device: &mcandle_core::Device, (_, _): (usize, usize), - ) -> candle_core::Result { + ) -> mcandle_core::Result { if images.len() > 1 { - candle_core::bail!("Can only process one image per batch"); // This is no different from phi3_input_processor + mcandle_core::bail!("Can only process one image per batch"); // This is no different from phi3_input_processor }; assert!(videos.is_empty()); let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize; diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs index 9fd539a09..f4be3a5ec 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs @@ -7,8 +7,8 @@ use std::sync::Arc; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{embedding, linear_no_bias as linear, Embedding, Module, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{embedding, linear_no_bias as linear, Embedding, Module, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear}; use crate::{ @@ -227,7 +227,7 @@ impl MlpLayer for Mlp { if let Some(t) = self.c_fc1.quantized_act_type() { x = x.to_dtype(t)?; } - let x = (candle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)? + let x = (mcandle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)? 
* MatMul.qmethod_matmul(&x, &*self.c_fc2)?)?; let mut res = MatMul.qmethod_matmul(&x, &*self.c_proj)?; if self.c_fc1.quantized_act_type().is_some() { diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs index f3d4ac8b9..b1d1e4e8d 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs @@ -3,8 +3,8 @@ use std::sync::Arc; /// Mistral LLM, https://github.com/mistralai/mistral-src -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{linear_no_bias, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{linear_no_bias, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear}; use crate::{ @@ -359,7 +359,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -401,7 +401,7 @@ impl Model { attention_mechanism: AttentionImplementation, ) -> Result { let mapper = normal_loading_metadata.mapper; - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mod.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mod.rs index 16da865cf..5654a7259 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/mod.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/mod.rs @@ -1,5 +1,5 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, Result, Tensor}; +use mcandle_core::{DType, Device, Result, Tensor}; use crate::pipeline::{ text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, @@ -50,7 +50,7 @@ impl OrdinaryRoPE { let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?; let cos = cos.narrow(0, index_pos, seq_len)?; let sin = sin.narrow(0, index_pos, seq_len)?; - candle_nn::rotary_emb::rope(x, &cos, &sin) + mcandle_nn::rotary_emb::rope(x, &cos, &sin) } } pub(crate) mod llama; diff --git a/mistralrs-core/src/vision_models/llava/llava_next.rs b/mistralrs-core/src/vision_models/llava/llava_next.rs index 58459e94f..9b483f932 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next.rs +++ b/mistralrs-core/src/vision_models/llava/llava_next.rs @@ -3,8 +3,8 @@ clippy::cast_precision_loss, clippy::too_many_arguments )] -use candle_core::{bail, DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{linear, Activation, Linear, VarBuilder}; +use mcandle_core::{bail, DType, Device, IndexOp, Result, Tensor}; +use mcandle_nn::{linear, Activation, Linear, VarBuilder}; use crate::amoe::{AnyMoeBaseModelMixin, MlpLayer}; use crate::device_map::DeviceMapper; @@ -383,7 +383,7 @@ impl VisionModel for Model { model_specific_args: Box, // pixel attention mask, or image sizes, or anything else metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, flash_params: &FlashParams, - ) -> candle_core::Result { + ) -> mcandle_core::Result { let LLaVANextVisionSpecificArgs { image_sizes, num_image_tokens, diff --git a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs index 5f43387be..49e125c7b 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs +++ 
b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs @@ -3,10 +3,10 @@ use std::any::Any; use std::num::NonZeroUsize; use std::sync::Arc; -use candle_core::Result; -use candle_core::{DType, Device, Tensor}; use image::GenericImageView; use itertools::Itertools; +use mcandle_core::Result; +use mcandle_core::{DType, Device, Tensor}; use regex_automata::meta::Regex; use tokenizers::Tokenizer; use tracing::warn; @@ -382,11 +382,11 @@ impl ImagePreProcessor for LLaVANextInputProcessor { images: Vec, videos: Vec>, config: &preprocessor_config::PreProcessorConfig, - device: &candle_core::Device, + device: &mcandle_core::Device, (_, _): (usize, usize), - ) -> candle_core::Result { + ) -> mcandle_core::Result { if images.len() > 1 { - candle_core::bail!("Can only process one image per batch"); // This is no different from phi3_input_processor + mcandle_core::bail!("Can only process one image per batch"); // This is no different from phi3_input_processor }; assert!(videos.is_empty()); diff --git a/mistralrs-core/src/vision_models/llava/utils.rs b/mistralrs-core/src/vision_models/llava/utils.rs index bb384fa1c..626cb34fb 100644 --- a/mistralrs-core/src/vision_models/llava/utils.rs +++ b/mistralrs-core/src/vision_models/llava/utils.rs @@ -4,11 +4,11 @@ clippy::too_many_arguments )] use crate::vision_models::preprocessor_config::PreProcessorConfig; -use candle_core::{DType, Device, Result, Tensor}; use image::{ imageops::{overlay, FilterType}, DynamicImage, GenericImageView, Rgb, RgbImage, }; +use mcandle_core::{DType, Device, Result, Tensor}; use std::cmp::min; pub(crate) fn get_anyres_image_grid_shape( diff --git a/mistralrs-core/src/vision_models/mllama/config.rs b/mistralrs-core/src/vision_models/mllama/config.rs index 5346e5b41..81f6804b5 100644 --- a/mistralrs-core/src/vision_models/mllama/config.rs +++ b/mistralrs-core/src/vision_models/mllama/config.rs @@ -1,5 +1,5 @@ -use candle_core::{Result, Tensor}; -use candle_nn::Module; +use mcandle_core::{Result, Tensor}; +use mcandle_nn::Module; use mistralrs_quant::QuantizedConfig; use crate::serde_default_fn; @@ -18,7 +18,7 @@ pub(super) enum VisionActivation { impl Module for VisionActivation { fn forward(&self, xs: &Tensor) -> Result { match self { - Self::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?), + Self::QuickGelu => xs * mcandle_nn::ops::sigmoid(&(xs * 1.702f64)?), Self::Gelu => xs.gelu_erf(), // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 Self::NewGelu => xs.gelu(), @@ -93,7 +93,7 @@ pub struct MLlamaTextConfig { pub(crate) rope_scaling: Option, pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, - pub(crate) hidden_act: candle_nn::Activation, + pub(crate) hidden_act: mcandle_nn::Activation, pub(crate) num_hidden_layers: usize, pub(crate) num_attention_heads: usize, pub(crate) num_key_value_heads: usize, diff --git a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs index 7433558ae..e46b4d1ba 100644 --- a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs @@ -7,9 +7,9 @@ use std::{ sync::{Arc, RwLock}, }; -use candle_core::{Context, DType, Device, Result, Tensor}; use image::{imageops::FilterType, DynamicImage}; use itertools::Itertools; +use mcandle_core::{Context, DType, Device, Result, Tensor}; use mistralrs_vision::{ ApplyTensorTransforms, ApplyTransforms, 
Normalize, Rescale, TensorTransforms, ToTensorNoNorm, Transforms, @@ -125,7 +125,7 @@ fn convert_sparse_cross_attention_mask_to_dense( max_num_tiles: usize, length: usize, dev: &Device, -) -> candle_core::Result { +) -> mcandle_core::Result { let bs = cross_attn_token_mask.len(); let max_num_images = cross_attn_token_mask.iter().map(|x| x.len()).max().unwrap(); diff --git a/mistralrs-core/src/vision_models/mllama/mod.rs b/mistralrs-core/src/vision_models/mllama/mod.rs index ed4e94a30..4b264e3c7 100644 --- a/mistralrs-core/src/vision_models/mllama/mod.rs +++ b/mistralrs-core/src/vision_models/mllama/mod.rs @@ -13,8 +13,8 @@ pub(crate) use inputs_processor::MLlamaProcessor; use text::MLlamaTextModel; use vision::MLlamaVisionModel; -use candle_core::{DType, Device, Result, Tensor, D}; -use candle_nn::{linear, Linear, Module, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor, D}; +use mcandle_nn::{linear, Linear, Module, VarBuilder}; use mistralrs_quant::QuantMethod; use crate::{ @@ -133,10 +133,10 @@ impl MLlamaModel { ) -> Result { let cross_attn_states = if let Some(pixel_values) = pixel_values { let Some(aspect_ratio_mask) = aspect_ratio_mask else { - candle_core::bail!("`aspect_ratio_mask` must be specified if `pixel_values` is."); + mcandle_core::bail!("`aspect_ratio_mask` must be specified if `pixel_values` is."); }; let Some(aspect_ratio_ids) = aspect_ratio_ids else { - candle_core::bail!("`aspect_ratio_ids` must be specified if `pixel_values` is."); + mcandle_core::bail!("`aspect_ratio_ids` must be specified if `pixel_values` is."); }; let vision_outputs = self.vision_model diff --git a/mistralrs-core/src/vision_models/mllama/text.rs b/mistralrs-core/src/vision_models/mllama/text.rs index 3a8550afd..34e65d97b 100644 --- a/mistralrs-core/src/vision_models/mllama/text.rs +++ b/mistralrs-core/src/vision_models/mllama/text.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, sync::Arc}; -use candle_core::{DType, Device, IndexOp, Result, Tensor}; -use candle_nn::{embedding, Activation, Embedding, Module, VarBuilder}; +use mcandle_core::{DType, Device, IndexOp, Result, Tensor}; +use mcandle_nn::{embedding, Activation, Embedding, Module, VarBuilder}; use mistralrs_quant::{linear_no_bias, QuantMethod, QuantMethodConfig, UnquantLinear}; use crate::{ @@ -390,7 +390,7 @@ impl MLlamaTextCrossAttention { } else if let Some((k_cache, v_cache)) = kv_cache { (k_cache.clone(), v_cache.clone()) } else { - candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!") + mcandle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!") }; let mut attn_output = Sdpa @@ -531,7 +531,7 @@ impl MLlamaTextModel { attention_mechanism: AttentionImplementation, ) -> Result { if !matches!(attention_mechanism, AttentionImplementation::Eager) { - candle_core::bail!("Expected eager attention implementation"); + mcandle_core::bail!("Expected eager attention implementation"); } let mapper = normal_loading_metadata.mapper; @@ -550,7 +550,7 @@ impl MLlamaTextModel { )? 
} else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device(embed_tokens.embeddings(), false)?, None, ), diff --git a/mistralrs-core/src/vision_models/mllama/vision.rs b/mistralrs-core/src/vision_models/mllama/vision.rs index 524c5d645..1c987be21 100644 --- a/mistralrs-core/src/vision_models/mllama/vision.rs +++ b/mistralrs-core/src/vision_models/mllama/vision.rs @@ -2,8 +2,8 @@ use std::{ops::Mul, sync::Arc}; -use candle_core::{DType, Result, Tensor, D}; -use candle_nn::{ +use mcandle_core::{DType, Result, Tensor, D}; +use mcandle_nn::{ conv2d_no_bias, embedding, layer_norm, Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Linear, Module, VarBuilder, }; @@ -139,22 +139,22 @@ impl MLlamaVisionAttention { fn new(cfg: &MLlamaVisionConfig, vb: VarBuilder) -> Result { let head_dim = cfg.hidden_size / cfg.num_attention_heads; Ok(Self { - q_proj: candle_nn::linear_no_bias( + q_proj: mcandle_nn::linear_no_bias( cfg.hidden_size, cfg.num_attention_heads * head_dim, vb.pp("q_proj"), )?, - k_proj: candle_nn::linear_no_bias( + k_proj: mcandle_nn::linear_no_bias( cfg.hidden_size, cfg.num_attention_heads * head_dim, vb.pp("k_proj"), )?, - v_proj: candle_nn::linear_no_bias( + v_proj: mcandle_nn::linear_no_bias( cfg.hidden_size, cfg.num_attention_heads * head_dim, vb.pp("v_proj"), )?, - o_proj: candle_nn::linear_no_bias( + o_proj: mcandle_nn::linear_no_bias( cfg.hidden_size, cfg.num_attention_heads * head_dim, vb.pp("o_proj"), diff --git a/mistralrs-core/src/vision_models/mod.rs b/mistralrs-core/src/vision_models/mod.rs index 7b93ae7a0..33ce1980e 100644 --- a/mistralrs-core/src/vision_models/mod.rs +++ b/mistralrs-core/src/vision_models/mod.rs @@ -1,6 +1,6 @@ use std::any::Any; -use candle_core::Tensor; +use mcandle_core::Tensor; pub(crate) mod clip; pub(crate) mod idefics2; diff --git a/mistralrs-core/src/vision_models/phi3.rs b/mistralrs-core/src/vision_models/phi3.rs index 207e240aa..5083bfbd6 100644 --- a/mistralrs-core/src/vision_models/phi3.rs +++ b/mistralrs-core/src/vision_models/phi3.rs @@ -2,11 +2,11 @@ // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py -use candle_core::{ +use either::Either; +use mcandle_core::{ shape::ShapeWithOneHole, DType, Device, IndexOp, Module, Result, Shape, Tensor, D, }; -use candle_nn::VarBuilder; -use either::Either; +use mcandle_nn::VarBuilder; use mistralrs_quant::{QuantMethod, QuantMethodConfig, QuantizedConfig, UnquantLinear}; use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc}; @@ -56,7 +56,7 @@ serde_default_fn!(bool, word_emb_default, false); #[derive(Debug, Clone, serde::Deserialize, Default)] pub struct Config { pub vocab_size: usize, - pub hidden_act: candle_nn::Activation, + pub hidden_act: mcandle_nn::Activation, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, @@ -120,7 +120,7 @@ impl ModuleWithMetadata for QuantMethodWrapper { } } -impl ModuleWithMetadata for candle_nn::Activation { +impl ModuleWithMetadata for mcandle_nn::Activation { fn device(&self) -> Device { unreachable!() } @@ -134,10 +134,10 @@ struct BigShapeWithOneHole((usize, usize, usize, usize, usize, ())); fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result { if prod_d == 0 { - candle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}") + mcandle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}") } if el_count % 
prod_d != 0 { - candle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}") + mcandle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}") } Ok(el_count / prod_d) } @@ -317,7 +317,7 @@ impl Attention { struct Mlp { gate_up_proj: Arc, down_proj: Arc, - act_fn: candle_nn::Activation, + act_fn: mcandle_nn::Activation, i_size: usize, params: Vec, } @@ -500,7 +500,7 @@ impl Module for EmbeddingLayers { #[derive(Debug)] pub struct ImageEmbedding { - wte: candle_nn::Embedding, + wte: mcandle_nn::Embedding, image_dim_out: usize, num_img_tokens: usize, glb_gn: Option, @@ -518,13 +518,13 @@ pub struct ImageEmbedding { impl ImageEmbedding { fn new( config: &Config, - wte: candle_nn::Embedding, + wte: mcandle_nn::Embedding, embed_config: &EmbedLayerConfig, vb: VarBuilder, ) -> Result { let hidden_size = config.hidden_size; if config.img_processor.name != "clip_vision_model" { - candle_core::bail!( + mcandle_core::bail!( "img_processor=`{}` not supported.", config.img_processor.name ); @@ -615,7 +615,7 @@ impl ImageEmbedding { } vec![ Box::new(QuantMethodWrapper(a)), - Box::new(candle_nn::Activation::Gelu), + Box::new(mcandle_nn::Activation::Gelu), Box::new(QuantMethodWrapper(b)), ] } @@ -647,12 +647,12 @@ impl ImageEmbedding { } vec![ Box::new(QuantMethodWrapper(a)), - Box::new(candle_nn::Activation::Gelu), + Box::new(mcandle_nn::Activation::Gelu), Box::new(QuantMethodWrapper(b)), ] } _ => { - candle_core::bail!("projection_cls=`{projection_cls}` not implemented."); + mcandle_core::bail!("projection_cls=`{projection_cls}` not implemented."); } }; @@ -691,7 +691,7 @@ impl ImageEmbedding { } else if self.type_feature == "cls_patch" { Ok(img_feature) } else { - candle_core::bail!("Unsupported image feature type {}", self.type_feature) + mcandle_core::bail!("Unsupported image feature type {}", self.type_feature) } } @@ -818,7 +818,7 @@ impl ImageEmbedding { )?); } other => { - candle_core::bail!("Invalid hd_transform_order=`{other}`"); + mcandle_core::bail!("Invalid hd_transform_order=`{other}`"); } } @@ -924,7 +924,7 @@ impl ImageEmbedding { pub struct Model { vision_embed_tokens: ImageEmbedding, - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -947,7 +947,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -1017,7 +1017,7 @@ impl Model { )?
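The `hole_size` helper in the phi3.rs hunk above resolves the single `()` hole in a reshape target: the known dimensions must divide the element count exactly. A small worked sketch of the same logic (the `main` values are illustrative; `mcandle_core::bail!` is assumed to behave like candle's `bail!`):

```rust
use mcandle_core::Result;

// Same logic as the hunk above: the hole is el_count / prod(known dims),
// and it only exists when the known dims divide el_count evenly.
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
    if prod_d == 0 {
        mcandle_core::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
    }
    if el_count % prod_d != 0 {
        mcandle_core::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
    }
    Ok(el_count / prod_d)
}

fn main() -> Result<()> {
    // 24 elements reshaped to (2, 3, ()): the hole must be 24 / (2 * 3) = 4.
    assert_eq!(hole_size(24, 2 * 3, &"(2, 3, ())")?, 4);
    // 25 elements cannot be reshaped to (2, 3, ()).
    assert!(hole_size(25, 2 * 3, &"(2, 3, ())").is_err());
    Ok(())
}
```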
} else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device( embed_tokens.embeddings(), normal_loading_metadata.loading_isq, diff --git a/mistralrs-core/src/vision_models/phi3_inputs_processor.rs b/mistralrs-core/src/vision_models/phi3_inputs_processor.rs index 7336df078..11bcbbbe3 100644 --- a/mistralrs-core/src/vision_models/phi3_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/phi3_inputs_processor.rs @@ -2,9 +2,9 @@ use std::{any::Any, num::NonZeroUsize, sync::Arc}; -use candle_core::{Device, Result, Tensor}; use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba}; use itertools::Itertools; +use mcandle_core::{Device, Result, Tensor}; use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms}; use regex_automata::meta::Regex; use tokenizers::Tokenizer; @@ -543,7 +543,7 @@ impl ImagePreProcessor for Phi3InputsProcessor { num_img_tokens.push(num_image_tokens); } if padded_images.len() > 1 { - candle_core::bail!("Can only process one image per batch"); + mcandle_core::bail!("Can only process one image per batch"); } let image_sizes = image_sizes[0]; diff --git a/mistralrs-core/src/vision_models/preprocessor_config.rs b/mistralrs-core/src/vision_models/preprocessor_config.rs index 3b603b87c..04357c8e3 100644 --- a/mistralrs-core/src/vision_models/preprocessor_config.rs +++ b/mistralrs-core/src/vision_models/preprocessor_config.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; -use candle_core::Result; use image::imageops::FilterType; +use mcandle_core::Result; use serde::Deserialize; #[derive(Deserialize, Debug, Clone)] @@ -45,7 +45,7 @@ impl ToFilter for Option { Some(2) | None => Ok(FilterType::Triangle), // BiLinear Some(3) => Ok(FilterType::CatmullRom), // BiCubic Some(4) => Ok(FilterType::Nearest), - Some(x) => candle_core::bail!("Filter number {x} not supported"), + Some(x) => mcandle_core::bail!("Filter number {x} not supported"), } } } diff --git a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs index 40a8fde3e..d0336ba79 100644 --- a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs @@ -5,8 +5,8 @@ use std::{ }; use anyhow::Result; -use candle_core::{Context, Device, IndexOp, Tensor}; use image::{imageops::FilterType, DynamicImage, GenericImageView}; +use mcandle_core::{Context, Device, IndexOp, Tensor}; use mistralrs_vision::{ ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms, }; @@ -499,16 +499,16 @@ impl Qwen2VLImageProcessor { factor: usize, min_pixels: usize, max_pixels: usize, - ) -> candle_core::Result<(usize, usize)> { + ) -> mcandle_core::Result<(usize, usize)> { if height < factor || width < factor { - candle_core::bail!( + mcandle_core::bail!( "height:{} or width:{} must be larger than factor:{}", height, width, factor ); } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 { - candle_core::bail!( + mcandle_core::bail!( "absolute aspect ratio must be smaller than 200, got {:.2}", height.max(width) as f64 / height.min(width) as f64 ); @@ -537,7 +537,7 @@ impl Qwen2VLImageProcessor { config: &PreProcessorConfig, device: &Device, (mut height, mut width): (u32, u32), - ) -> candle_core::Result<(Tensor, (u32, u32, u32))> { + ) -> mcandle_core::Result<(Tensor, (u32, u32, u32))> { let mut processed_images = Vec::new(); for mut image 
in images { @@ -638,7 +638,7 @@ impl ImagePreProcessor for Qwen2VLImageProcessor { config: &PreProcessorConfig, device: &Device, (_, _): (usize, usize), - ) -> candle_core::Result { + ) -> mcandle_core::Result { let mut pixel_values = Vec::new(); let mut vision_grid_thw = Vec::new(); diff --git a/mistralrs-core/src/vision_models/qwen2vl/mod.rs b/mistralrs-core/src/vision_models/qwen2vl/mod.rs index da88869ba..e36792016 100644 --- a/mistralrs-core/src/vision_models/qwen2vl/mod.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/mod.rs @@ -2,8 +2,8 @@ use std::{any::Any, sync::Arc}; -use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::VarBuilder; +use mcandle_core::{Context, DType, Device, IndexOp, Result, Tensor, D}; +use mcandle_nn::VarBuilder; use mistralrs_quant::QuantMethod; use text::Qwen2VLTextModel; use vision::Qwen2VLVisionModel; @@ -46,7 +46,7 @@ impl Qwen2VLModel { ) -> Result { if cfg.use_sliding_window { // TODO! - candle_core::bail!("Sliding window is unsupported for now!"); + mcandle_core::bail!("Sliding window is unsupported for now!"); } let vision = Qwen2VLVisionModel::new( &cfg.vision_config, @@ -447,7 +447,7 @@ impl VisionModel for Qwen2VLModel { (None, Some(_)) => (None, pixel_values), (None, None) => (None, None), (Some(_), Some(_)) => { - candle_core::bail!("Images and videos cannot be provided together.") + mcandle_core::bail!("Images and videos cannot be provided together.") } }; self.forward( diff --git a/mistralrs-core/src/vision_models/qwen2vl/text.rs b/mistralrs-core/src/vision_models/qwen2vl/text.rs index 3d6de2a84..4b62b1c85 100644 --- a/mistralrs-core/src/vision_models/qwen2vl/text.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/text.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{Embedding, Module, VarBuilder}; use mistralrs_quant::{QuantMethod, QuantMethodConfig, UnquantLinear}; use crate::{ @@ -292,12 +292,12 @@ impl Qwen2VLTextModel { attention_mechanism: AttentionImplementation, ) -> Result { if !matches!(attention_mechanism, AttentionImplementation::Eager) { - candle_core::bail!("Expected eager attention implementation"); + mcandle_core::bail!("Expected eager attention implementation"); } let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -356,7 +356,7 @@ impl Qwen2VLTextModel { )? 
} else { Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized( - candle_nn::Linear::new( + mcandle_nn::Linear::new( mapper.cast_nm_device( embed_tokens.embeddings(), normal_loading_metadata.loading_isq, diff --git a/mistralrs-core/src/vision_models/qwen2vl/vision.rs b/mistralrs-core/src/vision_models/qwen2vl/vision.rs index d049a987e..14851d679 100644 --- a/mistralrs-core/src/vision_models/qwen2vl/vision.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/vision.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use candle_core::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use mcandle_core::{DType, Device, IndexOp, Result, Tensor, D}; +use mcandle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; use mistralrs_quant::QuantMethod; use crate::{ @@ -23,7 +23,7 @@ struct PatchEmbed { impl PatchEmbed { fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result { if cfg.temporal_patch_size != 2 { - candle_core::bail!("Only support temporal patch size of 2"); + mcandle_core::bail!("Only support temporal patch size of 2"); } Ok(Self { proj: Conv3dNoBias::new( @@ -143,7 +143,7 @@ impl VisionAttention { Some(m) => att.broadcast_add(m)?, None => att, }; - att = candle_nn::ops::softmax_last_dim(&att)?; + att = mcandle_nn::ops::softmax_last_dim(&att)?; att.matmul(&v)? .transpose(0, 1)? .reshape((seq_len, ()))? @@ -208,8 +208,8 @@ impl PatchMerger { vb: VarBuilder, ) -> Result { let hidden_size = context_dim * spatial_merge_size.pow(2); - let mlp0 = candle_nn::linear(hidden_size, hidden_size, vb.pp("mlp.0"))?; - let mlp2 = candle_nn::linear(hidden_size, dim, vb.pp("mlp.2"))?; + let mlp0 = mcandle_nn::linear(hidden_size, hidden_size, vb.pp("mlp.0"))?; + let mlp2 = mcandle_nn::linear(hidden_size, dim, vb.pp("mlp.2"))?; Ok(Self { ln_q: layer_norm(context_dim, 1e-6, vb.pp("ln_q"))?, mlp0, diff --git a/mistralrs-core/src/xlora_models/classifier.rs b/mistralrs-core/src/xlora_models/classifier.rs index 251e774bd..7439b3baf 100644 --- a/mistralrs-core/src/xlora_models/classifier.rs +++ b/mistralrs-core/src/xlora_models/classifier.rs @@ -1,5 +1,5 @@ -use candle_core::{DType, Device, Result, Tensor, D}; -use candle_nn::{ +use mcandle_core::{DType, Device, Result, Tensor, D}; +use mcandle_nn::{ activation, linear, linear_no_bias, ops::softmax_last_dim, Dropout, Linear, Module, ModuleT, VarBuilder, }; @@ -38,7 +38,7 @@ impl XLoraClassifier { is_quantized: bool, ) -> Result { if config.enable_softmax_topk { - candle_core::bail!("`enable_softmax_topk` is not implemented"); + mcandle_core::bail!("`enable_softmax_topk` is not implemented"); } let (last, inner): (Linear, Vec>) = if config.xlora_depth diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index 68ccac8ba..220f96c64 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -14,8 +14,8 @@ use crate::{ }, utils::progress::NiceProgressBar, }; -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use tqdm::Iter; use tracing::info; @@ -436,7 +436,7 @@ impl DecoderLayer { } pub struct XLoraModel { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -472,7 +472,7 @@ impl XLoraModel { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); 
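Several of the `lm_head` hunks above (mllama text.rs, phi3.rs, qwen2vl text.rs) share the same fallback: when the output head is tied to the input embeddings, the embedding matrix is wrapped in a `Linear` and then in mistralrs_quant's `UnquantLinear`. A minimal sketch of just the tying step, assuming the renamed `mcandle_nn` exposes candle's `Embedding`/`Linear` API; the `tied_lm_head` name and toy sizes are illustrative:

```rust
use mcandle_core::{Device, Result, Tensor};
use mcandle_nn::{Embedding, Linear, Module};

// Reuse the (vocab_size, hidden_size) embedding table as the output projection:
// Linear computes x @ W^T, so the same weights map hidden states to vocab logits.
fn tied_lm_head(embed_tokens: &Embedding) -> Linear {
    Linear::new(embed_tokens.embeddings().clone(), None)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let table = Tensor::randn(0f32, 1f32, (16, 8), &dev)?; // toy vocab = 16, hidden = 8
    let embed = Embedding::new(table, 8);
    let head = tied_lm_head(&embed);
    let hidden = Tensor::randn(0f32, 1f32, (1, 4, 8), &dev)?;
    assert_eq!(head.forward(&hidden)?.dims(), &[1, 4, 16]);
    Ok(())
}
```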
- let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -568,7 +568,7 @@ impl XLoraModel { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { @@ -860,7 +860,7 @@ impl NormalModel for XLoraModel { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/gemma2.rs b/mistralrs-core/src/xlora_models/gemma2.rs index 040891063..abf0a9655 100644 --- a/mistralrs-core/src/xlora_models/gemma2.rs +++ b/mistralrs-core/src/xlora_models/gemma2.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, sync::Arc}; -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use tqdm::Iter; use tracing::info; @@ -477,7 +477,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -515,7 +515,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -610,7 +610,7 @@ impl Model { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? 
- candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { @@ -927,7 +927,7 @@ impl NormalModel for Model { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs index d112cd15e..0fdead8bc 100644 --- a/mistralrs-core/src/xlora_models/llama.rs +++ b/mistralrs-core/src/xlora_models/llama.rs @@ -12,8 +12,8 @@ use crate::{ }, utils::progress::NiceProgressBar, }; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{embedding, Embedding, Module, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{embedding, Embedding, Module, VarBuilder}; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -244,7 +244,7 @@ impl Mlp { if let Some(t) = self.c_fc1.quantized_act_type() { x = x.to_dtype(t)?; } - let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward( + let x = (mcandle_nn::ops::silu(&self.c_fc1.lora_forward( &x, scalings.clone(), global_scaling_weight, @@ -617,7 +617,7 @@ impl XLoraLlama { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } let ln_f = RmsNorm::new( cfg.hidden_size, @@ -821,7 +821,7 @@ impl NormalModel for XLoraLlama { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.blocks.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index e570717e9..d1b22d017 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -13,8 +13,8 @@ use crate::{ utils::progress::NiceProgressBar, }; /// Mistral LLM, https://github.com/mistralai/mistral-src -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -433,7 +433,7 @@ impl DecoderLayer { } pub struct XLoraModel { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -469,7 +469,7 @@ impl XLoraModel { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -564,7 +564,7 @@ impl XLoraModel { )?; if 
xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { embed_tokens, @@ -858,7 +858,7 @@ impl NormalModel for XLoraModel { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index 2556be69a..3827de2c8 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -15,8 +15,8 @@ use crate::{ /// Mixtral Model /// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py /// https://mistral.ai/news/mixtral-of-experts/ -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -411,7 +411,7 @@ impl SparseMoeBlock { router_logits = router_logits.to_dtype(original_dtype)?; } - let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + let routing_weights = mcandle_nn::ops::softmax_last_dim(&router_logits)?; // In order to extract topk, we extract the data from the tensor and manipulate it // directly. Maybe we will want to use some custom ops instead at some point. @@ -574,7 +574,7 @@ impl DecoderLayer { } pub struct XLoraModel { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -610,7 +610,7 @@ impl XLoraModel { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -705,7 +705,7 @@ impl XLoraModel { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? 
- candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { embed_tokens, @@ -994,7 +994,7 @@ impl NormalModel for XLoraModel { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/mod.rs b/mistralrs-core/src/xlora_models/mod.rs index 1c82263fc..596b4b1bc 100644 --- a/mistralrs-core/src/xlora_models/mod.rs +++ b/mistralrs-core/src/xlora_models/mod.rs @@ -17,11 +17,11 @@ use crate::{ lora::Ordering, pipeline::{text_models_inputs_processor::FlashParams, EitherCache}, }; -use candle_core::{DType, Device, Result, Tensor}; pub(crate) use config::XLoraConfig; pub(crate) use gemma::XLoraModel as XLoraGemma; pub(crate) use gemma2::Model as XLoraGemma2; pub(crate) use llama::XLoraLlama; +use mcandle_core::{DType, Device, Result, Tensor}; pub(crate) use mistral::XLoraModel as XLoraMistral; pub(crate) use mixtral::XLoraModel as XLoraMixtral; pub(crate) use phi2::Model as XLoraPhi2; @@ -148,7 +148,7 @@ fn verify_sanity_adapters(ordering: &Ordering, supported_layers: &[&str]) -> Res } for path in ordering.layers.as_ref().unwrap().keys() { if !supported_layers.iter().any(|layer| path.ends_with(layer)) { - candle_core::bail!("Got a layer name `{path}` in the ordering, expected it to end with one of {supported_layers:?}"); + mcandle_core::bail!("Got a layer name `{path}` in the ordering, expected it to end with one of {supported_layers:?}"); } } Ok(()) diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs index 38c9e5b07..aef085581 100644 --- a/mistralrs-core/src/xlora_models/phi2.rs +++ b/mistralrs-core/src/xlora_models/phi2.rs @@ -19,8 +19,8 @@ use crate::{ /// There is an alternative implementation of the phi model in mixformers.rs. /// This corresponds to the model update made with the following commit: /// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder}; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{embedding, layer_norm, Embedding, LayerNorm, RotaryEmbedding, VarBuilder}; use mistralrs_quant::QuantMethod; use tqdm::Iter; use tracing::info; @@ -530,7 +530,7 @@ impl Model { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? 
- candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { embed_tokens, @@ -809,7 +809,7 @@ impl NormalModel for Model { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 9ad19e215..1937bdada 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -14,8 +14,8 @@ use crate::{ }, utils::progress::NiceProgressBar, }; -use candle_core::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::VarBuilder; +use mcandle_core::{DType, Device, Module, Result, Tensor, D}; +use mcandle_nn::VarBuilder; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -379,7 +379,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: RmsNorm, lm_head: Arc, @@ -415,7 +415,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -493,7 +493,7 @@ impl Model { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { embed_tokens, @@ -774,7 +774,7 @@ impl NormalModel for Model { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/quantized_llama.rs b/mistralrs-core/src/xlora_models/quantized_llama.rs index affd1a4f7..d8d3acd2c 100644 --- a/mistralrs-core/src/xlora_models/quantized_llama.rs +++ b/mistralrs-core/src/xlora_models/quantized_llama.rs @@ -10,10 +10,10 @@ use crate::lora::{ }; use crate::pipeline::text_models_inputs_processor::FlashParams; use crate::utils::progress::NiceProgressBar; -use candle_core::quantized::ggml_file; -use candle_core::quantized::QMatMul; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, RotaryEmbedding, VarBuilder}; +use mcandle_core::quantized::ggml_file; +use mcandle_core::quantized::QMatMul; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::{Embedding, Module, RotaryEmbedding, VarBuilder}; use tqdm::Iter; use tracing::info; @@ -68,7 +68,7 @@ impl Mlp { is_scaling_pass, )?; self.feed_forward_w2.lora_forward( - &(candle_nn::ops::silu(&w1)? * w3)?, + &(mcandle_nn::ops::silu(&w1)? 
* w3)?, scalings.clone(), global_scaling_weight, is_scaling_pass, @@ -103,7 +103,7 @@ impl MlpOrMoe { let (b_size, seq_len, hidden_dim) = xs.dims3()?; let xs = xs.reshape(((), hidden_dim))?; let router_logits = MatMul.qmatmul(&xs, feed_forward_gate_inp)?; - let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + let routing_weights = mcandle_nn::ops::softmax_last_dim(&router_logits)?; // In order to extract topk, we extract the data from the tensor and manipulate it // directly. Maybe we will want to use some custom ops instead at some point. @@ -461,7 +461,7 @@ impl ModelConfig::FromAdapterGGML for ModelWeights { )?; if xlora_config.is_some() && output.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), @@ -513,11 +513,11 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { rope_freq_base, key_length, value_length, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let head_dim = key_length; if key_length != value_length { - candle_core::bail!( + mcandle_core::bail!( "Expected key_length == value_length, got {key_length} != {value_length}" ); } @@ -759,7 +759,7 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { )?; if xlora_config.is_some() && output.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { tok_embeddings: Embedding::new(tok_embeddings, embedding_length), @@ -781,7 +781,7 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { impl ModelWeights { pub fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/quantized_phi3.rs b/mistralrs-core/src/xlora_models/quantized_phi3.rs index 62b5d06ac..c7f81ebbc 100644 --- a/mistralrs-core/src/xlora_models/quantized_phi3.rs +++ b/mistralrs-core/src/xlora_models/quantized_phi3.rs @@ -21,11 +21,11 @@ use crate::pipeline::EitherCache; use crate::utils::progress::NiceProgressBar; use crate::DeviceMapMetadata; use crate::Topology; -use candle_core::quantized::QMatMul; -use candle_core::quantized::QTensor; -use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::Embedding; -use candle_nn::VarBuilder; +use mcandle_core::quantized::QMatMul; +use mcandle_core::quantized::QTensor; +use mcandle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use mcandle_nn::Embedding; +use mcandle_nn::VarBuilder; use tqdm::Iter; use tracing::info; @@ -104,7 +104,7 @@ impl LayerWeights { for (i, offset) in seqlen_offsets.iter().enumerate() { let cos = self.cos.narrow(0, *offset, seq_len)?; let sin = self.sin.narrow(0, *offset, seq_len)?; - outputs.push(candle_nn::rotary_emb::rope( + 
outputs.push(mcandle_nn::rotary_emb::rope( &xs.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin, @@ -250,7 +250,7 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { rope_dim, rms_eps, context_window, - } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?; + } = PropsGGUF::try_from(metadata).or_else(|err| mcandle_core::bail!("{err}"))?; let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, context_window)?; @@ -368,7 +368,7 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { )?; if xlora_config.is_some() && output.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? - candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { tok_embeddings: Embedding::new(tok_embeddings, embedding_length), @@ -390,7 +390,7 @@ impl ModelConfig::FromAdapterGGUF for ModelWeights { impl ModelWeights { pub fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-core/src/xlora_models/starcoder2.rs b/mistralrs-core/src/xlora_models/starcoder2.rs index 2dbc6c437..c2d04efed 100644 --- a/mistralrs-core/src/xlora_models/starcoder2.rs +++ b/mistralrs-core/src/xlora_models/starcoder2.rs @@ -1,7 +1,7 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use candle_core::{DType, Device, Module, Result, Tensor}; -use candle_nn::{layer_norm, LayerNorm, VarBuilder}; +use mcandle_core::{DType, Device, Module, Result, Tensor}; +use mcandle_nn::{layer_norm, LayerNorm, VarBuilder}; use mistralrs_quant::QuantMethod; use std::{collections::HashMap, sync::Arc}; use tqdm::Iter; @@ -421,7 +421,7 @@ impl DecoderLayer { } pub struct Model { - embed_tokens: candle_nn::Embedding, + embed_tokens: mcandle_nn::Embedding, layers: Vec, norm: LayerNorm, lm_head: Arc, @@ -457,7 +457,7 @@ impl Model { let mapper = normal_loading_metadata.mapper; let vb_m = vb.pp("model"); - let embed_tokens = candle_nn::embedding( + let embed_tokens = mcandle_nn::embedding( cfg.vocab_size, cfg.hidden_size, mapper.set_nm_device(vb_m.pp("embed_tokens"), false), @@ -546,7 +546,7 @@ impl Model { )?; if xlora_config.is_some() && lm_head.is_lora() { // This is why we can pass dummy values (..., None, 1.0, None)? 
- candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); + mcandle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA."); } Ok(Self { embed_tokens, @@ -834,7 +834,7 @@ impl NormalModel for Model { } fn activate_adapters(&mut self, adapter_names: Vec) -> Result { if self.xlora_classifier.is_some() { - candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); + mcandle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same."); } let mut sum = 0; for layer in self.layers.iter_mut() { diff --git a/mistralrs-paged-attn/Cargo.toml b/mistralrs-paged-attn/Cargo.toml index 5e16e57a0..f48da9ac1 100644 --- a/mistralrs-paged-attn/Cargo.toml +++ b/mistralrs-paged-attn/Cargo.toml @@ -12,7 +12,7 @@ license.workspace = true homepage.workspace = true [dependencies] -candle-core.workspace = true +mcandle-core.workspace = true half.workspace = true float8.workspace = true diff --git a/mistralrs-paged-attn/src/backend/cache.rs b/mistralrs-paged-attn/src/backend/cache.rs index f9c2152ba..13e99aecf 100644 --- a/mistralrs-paged-attn/src/backend/cache.rs +++ b/mistralrs-paged-attn/src/backend/cache.rs @@ -2,11 +2,11 @@ use std::{collections::HashMap, iter::zip, ptr::NonNull}; use crate::backend::get_or_load_func; -use candle_core::cuda::cudarc::driver::LaunchAsync; -use candle_core::cuda::WrapErr; -use candle_core::cuda_backend::CudaStorageSlice; -use candle_core::Result; -use candle_core::{ +use mcandle_core::cuda::cudarc::driver::LaunchAsync; +use mcandle_core::cuda::WrapErr; +use mcandle_core::cuda_backend::CudaStorageSlice; +use mcandle_core::Result; +use mcandle_core::{ cuda_backend::cudarc::driver::{CudaSlice, DevicePtr, LaunchConfig}, Device, IndexOp, Storage, Tensor, }; @@ -24,14 +24,14 @@ pub fn copy_blocks( panic!("Expected the key caches to be on a CUDA device.") }; if !cache_dev.same_device(value_caches.first().unwrap().device()) { - candle_core::bail!( + mcandle_core::bail!( "`key` and `value` caches have different devices, got {:?} and {:?} respectively.", cache_dev, value_caches.first().unwrap().device() ); } if key_caches.first().unwrap().dtype() != value_caches.first().unwrap().dtype() { - candle_core::bail!( + mcandle_core::bail!( "Key and value caches have different types, got {:?} and {:?}.", key_caches.first().unwrap().dtype(), value_caches.first().unwrap().dtype() @@ -87,7 +87,7 @@ pub fn copy_blocks( (ptr_key, ptr_value) } _ => { - candle_core::bail!("only f32, f16 and bf16 input data type supported!",); + mcandle_core::bail!("only f32, f16 and bf16 input data type supported!",); } }; key_cache_ptrs.push(key_ptr + key_offset); @@ -171,7 +171,7 @@ pub unsafe fn swap_blocks( match (src.device(), dst.device()) { (Device::Cuda(src_dev), Device::Cuda(dst_dev)) => { if src_dev.ordinal() != dst_dev.ordinal() { - candle_core::bail!("Tensors must be on the same device to copy, got ordinals {} (src) and {} (dst).", src_dev.ordinal(), dst_dev.ordinal()); + mcandle_core::bail!("Tensors must be on the same device to copy, got ordinals {} (src) and {} (dst).", src_dev.ordinal(), dst_dev.ordinal()); } let (src_storage, src_layout) = src.storage_and_layout(); let (dst_storage, dst_layout) = dst.storage_and_layout(); @@ -200,7 +200,7 @@ pub unsafe fn swap_blocks( (ptr_src, ptr_dst) } _ => { - candle_core::bail!("only f32, f16 and bf16 input data type supported!") + mcandle_core::bail!("only f32, f16 and bf16 input data type supported!") } 
}; @@ -250,7 +250,7 @@ pub unsafe fn swap_blocks( } } (src, dst) => { - candle_core::bail!("Tensors must be on either the GPU or CPU to swap, got {src:?} (src) and {dst:?} (dst)."); + mcandle_core::bail!("Tensors must be on either the GPU or CPU to swap, got {src:?} (src) and {dst:?} (dst)."); } } diff --git a/mistralrs-paged-attn/src/backend/mod.rs b/mistralrs-paged-attn/src/backend/mod.rs index 579caf44d..4b0310d68 100644 --- a/mistralrs-paged-attn/src/backend/mod.rs +++ b/mistralrs-paged-attn/src/backend/mod.rs @@ -6,7 +6,7 @@ use std::{ ptr::{addr_of, NonNull}, }; -use candle_core::{ +use mcandle_core::{ cuda::cudarc::driver::DeviceRepr, cuda_backend::cudarc::driver::CudaFunction, CudaDevice, DType, Result, }; diff --git a/mistralrs-paged-attn/src/backend/paged_attention.rs b/mistralrs-paged-attn/src/backend/paged_attention.rs index 07b31b67b..da6eea8ee 100644 --- a/mistralrs-paged-attn/src/backend/paged_attention.rs +++ b/mistralrs-paged-attn/src/backend/paged_attention.rs @@ -1,11 +1,11 @@ use crate::ffi; use crate::ffi::{paged_attention_v1, paged_attention_v2}; -use candle::backend::BackendStorage; -use candle::cuda_backend::cudarc::driver::DevicePtr; -use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, CudaStorage, DType, Layout, Result, Shape, Storage, Tensor}; -use candle_core as candle; use half::{bf16, f16}; +use mcandle::backend::BackendStorage; +use mcandle::cuda_backend::cudarc::driver::DevicePtr; +use mcandle::cuda_backend::WrapErr; +use mcandle::{CpuStorage, CudaStorage, DType, Layout, Result, Shape, Storage, Tensor}; +use mcandle_core as mcandle; use std::ffi::c_int; struct PagedAttention { @@ -21,7 +21,7 @@ struct PagedAttention { impl PagedAttention { fn cuda_fwd_t< - T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + T: mcandle::cuda_backend::CudaDType + mcandle::cuda_backend::cudarc::driver::DeviceRepr, >( &self, q: &CudaStorage, @@ -32,7 +32,7 @@ impl PagedAttention { DType::F16 => 0, DType::BF16 => 1, DType::F32 => 2, - dtype => candle::bail!("dtype {dtype:?} is not supported"), + dtype => mcandle::bail!("dtype {dtype:?} is not supported"), }; let dev = q.device(); @@ -41,25 +41,25 @@ impl PagedAttention { let (kc, kc_l) = self.key_cache.storage_and_layout(); let kc = match &*kc { Storage::Cuda(kc) => kc, - _ => candle::bail!("key_cache must be a cuda tensor"), + _ => mcandle::bail!("key_cache must be a cuda tensor"), }; let (vc, vc_l) = self.value_cache.storage_and_layout(); let vc = match &*vc { Storage::Cuda(vc) => vc, - _ => candle::bail!("value_cache must be a cuda tensor"), + _ => mcandle::bail!("value_cache must be a cuda tensor"), }; let (bt, bt_l) = self.block_tables.storage_and_layout(); let bt = match &*bt { Storage::Cuda(bt) => bt, - _ => candle::bail!("block_tables must be a cuda tensor"), + _ => mcandle::bail!("block_tables must be a cuda tensor"), }; let (cl, cl_l) = self.context_lens.storage_and_layout(); let cl = match &*cl { Storage::Cuda(cl) => cl, - _ => candle::bail!("context_lens must be a cuda tensor"), + _ => mcandle::bail!("context_lens must be a cuda tensor"), }; let q_rank = q_l.stride().len(); @@ -67,21 +67,21 @@ impl PagedAttention { let vc_rank = vc_l.stride().len(); if q_rank != 3 { - candle::bail!( + mcandle::bail!( "paged-attention expects `q` tensor to be of rank 3 \ (q: {q_l:?})" ) } if kc_rank != 5 { - candle::bail!( + mcandle::bail!( "paged-attention expects `key_cache` tensor to be of rank 5 \ (key_cache: {kc_l:?})" ) } if vc_rank != 4 { - candle::bail!( + mcandle::bail!( 
"paged-attention expects `value_cache` tensor to be of rank 4 \ (value_cache: {vc_l:?})" ) @@ -109,13 +109,13 @@ impl PagedAttention { || head_size == 128 || head_size == 256) { - candle::bail!("`head_size` must be one of 64, 80, 96, 112, 128 or 256"); + mcandle::bail!("`head_size` must be one of 64, 80, 96, 112, 128 or 256"); } let (num_seqs_bt, max_num_blocks_per_seq) = bt_l.shape().dims2()?; if num_seqs_bt != num_seqs { - candle::bail!( + mcandle::bail!( "shape mismatch block_tables {:?}, expected {:?}", bt_l.shape(), (num_seqs, max_num_blocks_per_seq) @@ -124,7 +124,7 @@ impl PagedAttention { let (num_blocks, num_kv_heads, head_size_kc, block_size, x) = kc_l.shape().dims5()?; if head_size_kc != head_size / x { - candle::bail!( + mcandle::bail!( "shape mismatch value_cache {:?}, expected {:?}", vc_l.shape(), (num_blocks, num_heads, head_size / x, block_size, x) @@ -132,7 +132,7 @@ impl PagedAttention { } if (num_blocks, num_kv_heads, head_size, block_size) != vc_l.shape().dims4()? { - candle::bail!( + mcandle::bail!( "shape mismatch key_cache {:?} and value_cache {:?}", kc_l.shape(), vc_l.shape() @@ -140,7 +140,7 @@ impl PagedAttention { } if (num_seqs) != cl_l.shape().dims1()? { - candle::bail!( + mcandle::bail!( "shape mismatch context_lens {:?}, expected {:?}", cl_l.shape(), (num_seqs) @@ -234,13 +234,13 @@ impl PagedAttention { } } -impl candle::CustomOp1 for PagedAttention { +impl mcandle::CustomOp1 for PagedAttention { fn name(&self) -> &'static str { "paged-attention" } fn cpu_fwd(&self, _: &CpuStorage, _: &Layout) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for paged-attention") + mcandle::bail!("no cpu support for paged-attention") } fn cuda_fwd(&self, q: &CudaStorage, q_l: &Layout) -> Result<(CudaStorage, Shape)> { @@ -248,7 +248,7 @@ impl candle::CustomOp1 for PagedAttention { DType::F32 => self.cuda_fwd_t::(q, q_l), DType::F16 => self.cuda_fwd_t::(q, q_l), DType::BF16 => self.cuda_fwd_t::(q, q_l), - dt => candle::bail!("paged-attention is only supported for f32/f16/bf16 ({dt:?})"), + dt => mcandle::bail!("paged-attention is only supported for f32/f16/bf16 ({dt:?})"), } } } @@ -296,7 +296,7 @@ pub fn paged_attention( } fn update_cache< - T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + T: mcandle::cuda_backend::CudaDType + mcandle::cuda_backend::cudarc::driver::DeviceRepr, >( key: &Tensor, value: &Tensor, @@ -310,37 +310,37 @@ fn update_cache< DType::F16 => 0, DType::BF16 => 1, DType::F32 => 2, - dtype => candle::bail!("dtype {dtype:?} is not supported"), + dtype => mcandle::bail!("dtype {dtype:?} is not supported"), }; let (k, k_l) = key.storage_and_layout(); let k = match &*k { Storage::Cuda(k) => k, - _ => candle::bail!("key must be a cuda tensor"), + _ => mcandle::bail!("key must be a cuda tensor"), }; let (v, v_l) = value.storage_and_layout(); let v = match &*v { Storage::Cuda(v) => v, - _ => candle::bail!("value must be a cuda tensor"), + _ => mcandle::bail!("value must be a cuda tensor"), }; let (kc, kc_l) = key_cache.storage_and_layout(); let kc = match &*kc { Storage::Cuda(kc) => kc, - _ => candle::bail!("key_cache must be a cuda tensor"), + _ => mcandle::bail!("key_cache must be a cuda tensor"), }; let (vc, vc_l) = value_cache.storage_and_layout(); let vc = match &*vc { Storage::Cuda(vc) => vc, - _ => candle::bail!("value_cache must be a cuda tensor"), + _ => mcandle::bail!("value_cache must be a cuda tensor"), }; let (s, s_l) = slot_mapping.storage_and_layout(); let s = match &*s { Storage::Cuda(s) => s, 
- _ => candle::bail!("slot_mapping must be a cuda tensor"), + _ => mcandle::bail!("slot_mapping must be a cuda tensor"), }; let k_rank = k_l.stride().len(); @@ -349,18 +349,18 @@ fn update_cache< let vc_rank = vc_l.stride().len(); if k_rank != 3 || v_rank != 3 { - candle::bail!("paged-attention expects input tensors of rank 3 (k: {k_l:?}, v: {v_l:?})") + mcandle::bail!("paged-attention expects input tensors of rank 3 (k: {k_l:?}, v: {v_l:?})") } if kc_rank != 5 { - candle::bail!( + mcandle::bail!( "paged-attention expects `key_cache` tensor to be of rank 5 \ (key_cache: {kc_l:?})" ) } if vc_rank != 4 { - candle::bail!( + mcandle::bail!( "paged-attention expects `value_cache` tensor to be of rank 4 \ (value_cache: {vc_l:?})" ) @@ -382,12 +382,12 @@ fn update_cache< let (num_tokens, num_heads, head_size) = k_l.shape().dims3()?; if (num_tokens, num_heads, head_size) != v_l.shape().dims3()? { - candle::bail!("shape mismatch k {:?} and v {:?}", k_l.shape(), v_l.shape()) + mcandle::bail!("shape mismatch k {:?} and v {:?}", k_l.shape(), v_l.shape()) } let (num_blocks, num_heads_kc, head_size_kc, block_size, x) = kc_l.shape().dims5()?; if num_heads_kc != num_heads || head_size_kc != head_size / x { - candle::bail!( + mcandle::bail!( "shape mismatch value_cache {:?}, expected {:?}", vc_l.shape(), (num_blocks, num_heads, head_size / x, block_size, x) @@ -395,7 +395,7 @@ fn update_cache< } if (num_blocks, num_heads, head_size, block_size) != vc_l.shape().dims4()? { - candle::bail!( + mcandle::bail!( "shape mismatch key_cache {:?} and value_cache {:?}", kc_l.shape(), vc_l.shape() @@ -403,7 +403,7 @@ fn update_cache< } if (num_tokens) != s_l.shape().dims1()? { - candle::bail!( + mcandle::bail!( "shape mismatch slot_mapping {:?}, expected {:?}", s_l.shape(), (num_tokens) @@ -461,7 +461,7 @@ pub fn reshape_and_cache( DType::BF16 => update_cache::(key, value, key_cache, value_cache, slot_mapping), DType::F32 => update_cache::(key, value, key_cache, value_cache, slot_mapping), dt => { - candle::bail!("reshape_and_cache is only supported for f32, f16 and bf16 ({dt:?})") + mcandle::bail!("reshape_and_cache is only supported for f32, f16 and bf16 ({dt:?})") } } } diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml index d4e172940..3b297b002 100644 --- a/mistralrs-pyo3/Cargo.toml +++ b/mistralrs-pyo3/Cargo.toml @@ -20,7 +20,7 @@ pyo3.workspace = true mistralrs-core = { version = "0.3.2", path = "../mistralrs-core", features = ["pyo3_macros"] } serde.workspace = true serde_json.workspace = true -candle-core.workspace = true +mcandle-core.workspace = true indexmap.workspace = true accelerate-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true } @@ -38,9 +38,9 @@ anyhow.workspace = true pyo3-build-config = "0.22" [features] -cuda = ["candle-core/cuda", "mistralrs-core/cuda"] -cudnn = ["candle-core/cudnn", "mistralrs-core/cudnn"] -metal = ["candle-core/metal", "mistralrs-core/metal"] +cuda = ["mcandle-core/cuda", "mistralrs-core/cuda"] +cudnn = ["mcandle-core/cudnn", "mistralrs-core/cudnn"] +metal = ["mcandle-core/metal", "mistralrs-core/metal"] flash-attn = ["cuda", "mistralrs-core/flash-attn"] accelerate = ["mistralrs-core/accelerate"] mkl = ["mistralrs-core/mkl"] diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index aea3f8ec8..969f197a7 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -17,7 +17,7 @@ use stream::ChatCompletionStreamer; use tokio::sync::mpsc::channel; use util::{PyApiErr, PyApiResult}; -use 
candle_core::{Device, Result}; +use mcandle_core::{Device, Result}; use mistralrs_core::{ initialize_logging, paged_attn_supported, parse_isq_value, AnyMoeLoader, ChatCompletionResponse, CompletionResponse, Constraint, DefaultSchedulerMethod, diff --git a/mistralrs-pyo3/src/util.rs b/mistralrs-pyo3/src/util.rs index 7086223f3..290782375 100644 --- a/mistralrs-pyo3/src/util.rs +++ b/mistralrs-pyo3/src/util.rs @@ -42,8 +42,8 @@ impl From for PyApiErr { } } -impl From<&candle_core::Error> for PyApiErr { - fn from(value: &candle_core::Error) -> Self { +impl From<&mcandle_core::Error> for PyApiErr { + fn from(value: &mcandle_core::Error) -> Self { Self::from(value.to_string()) } } diff --git a/mistralrs-quant/Cargo.toml b/mistralrs-quant/Cargo.toml index e56f0a181..1d7f15c76 100644 --- a/mistralrs-quant/Cargo.toml +++ b/mistralrs-quant/Cargo.toml @@ -12,8 +12,8 @@ license.workspace = true homepage.workspace = true [dependencies] -candle-core.workspace = true -candle-nn.workspace = true +mcandle-core.workspace = true +mcandle-nn.workspace = true half.workspace = true serde.workspace = true lazy_static = "1.4" @@ -27,8 +27,8 @@ metal = { version = "0.27.0", features = ["mps"], optional = true } thiserror = "1" [features] -cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:bindgen_cuda"] -metal = ["candle-core/metal", "candle-nn/metal", "dep:metal"] +cuda = ["mcandle-core/cuda", "mcandle-nn/cuda", "dep:bindgen_cuda"] +metal = ["mcandle-core/metal", "mcandle-nn/metal", "dep:metal"] [build-dependencies] bindgen_cuda = { version = "0.1.5", optional = true } diff --git a/mistralrs-quant/src/cublaslt/api.rs b/mistralrs-quant/src/cublaslt/api.rs index 1fb0cea43..1b8726e9a 100644 --- a/mistralrs-quant/src/cublaslt/api.rs +++ b/mistralrs-quant/src/cublaslt/api.rs @@ -1,11 +1,11 @@ -use candle_core::cuda::cudarc::driver::DevicePtr; use float8::F8E4M3; +use mcandle_core::cuda::cudarc::driver::DevicePtr; use std::ffi::c_int; -use candle_core::backend::BackendStorage; -use candle_core::cuda_backend::WrapErr; -use candle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor}; use half::{bf16, f16}; +use mcandle_core::backend::BackendStorage; +use mcandle_core::cuda_backend::WrapErr; +use mcandle_core::{CpuStorage, DType, Device, Layout, Result, Shape, Storage, Tensor}; use std::sync::Arc; use super::matmul::{Activation, CudaBlasLT, Matmul, MatmulConfig, OutSlice}; @@ -18,7 +18,7 @@ impl CublasLt { pub fn new(device: &Device) -> Result { let dev = match device { Device::Cuda(d) => d, - _ => candle_core::bail!("`device` must be a `cuda` device"), + _ => mcandle_core::bail!("`device` must be a `cuda` device"), }; let inner = CudaBlasLT::new(dev.cuda_device()).unwrap(); @@ -44,13 +44,13 @@ pub struct CublasLTBatchMatmulF8 { impl CublasLTBatchMatmulF8 { pub fn fwd_f8e4m3( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -58,21 +58,21 @@ impl CublasLTBatchMatmulF8 { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + 
mcandle_core::bail!("`b` must have the same batch size as `a`") } if !self.a_scale.dims().is_empty() || self.a_scale.dtype() != DType::F32 { - candle_core::bail!("`a_scale` must be a f32 scalar."); + mcandle_core::bail!("`a_scale` must be a f32 scalar."); } if !self.b_scale.dims().is_empty() || self.b_scale.dtype() != DType::F32 { - candle_core::bail!("`b_scale` must be a f32 scalar."); + mcandle_core::bail!("`b_scale` must be a f32 scalar."); } if !self.d_scale.dims().is_empty() || self.d_scale.dtype() != DType::F32 { - candle_core::bail!("`d_scale` must be a f32 scalar."); + mcandle_core::bail!("`d_scale` must be a f32 scalar."); } let (a_s, _) = self.a_scale.storage_and_layout(); let (b_s, _) = self.b_scale.storage_and_layout(); @@ -80,15 +80,15 @@ impl CublasLTBatchMatmulF8 { let a_scale = match &*a_s { Storage::Cuda(scale) => scale.as_cuda_slice::()?, - _ => candle_core::bail!("`a_scale` must be a cuda tensor"), + _ => mcandle_core::bail!("`a_scale` must be a cuda tensor"), }; let b_scale = match &*b_s { Storage::Cuda(scale) => scale.as_cuda_slice::()?, - _ => candle_core::bail!("`b_scale` must be a cuda tensor"), + _ => mcandle_core::bail!("`b_scale` must be a cuda tensor"), }; let d_scale = match &*d_s { Storage::Cuda(scale) => scale.as_cuda_slice::()?, - _ => candle_core::bail!("`d_scale` must be a cuda tensor"), + _ => mcandle_core::bail!("`d_scale` must be a cuda tensor"), }; let lda = k; @@ -103,7 +103,7 @@ impl CublasLTBatchMatmulF8 { let (bias, bias_stride) = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.dims().len() == 1 { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } ( Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)), @@ -111,10 +111,10 @@ impl CublasLTBatchMatmulF8 { ) } else { if bias_l.shape().dims2()?.1 != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } if bias_l.shape().dims2()?.0 != batch_size { - candle_core::bail!("Bias batch size must match batch size of `a`"); + mcandle_core::bail!("Bias batch size must match batch size of `a`"); } let bias_stride = bias_l.stride()[0] as i64; ( @@ -130,22 +130,22 @@ impl CublasLTBatchMatmulF8 { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? 
!= (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -185,7 +185,7 @@ impl CublasLTBatchMatmulF8 { for case in cases { if case % 16 != 0 { - candle_core::bail!("F8 cuBLASlt matmul must match all cases described here: https://docs.nvidia.com/cuda/cublas/#tensor-core-usage"); + mcandle_core::bail!("F8 cuBLASlt matmul must match all cases described here: https://docs.nvidia.com/cuda/cublas/#tensor-core-usage"); } } @@ -224,12 +224,12 @@ impl CublasLTBatchMatmulF8 { bias.as_ref(), self.act.as_ref(), ) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } let out = match out { - OutSlice::BF16(s) => candle_core::CudaStorage::wrap_cuda_slice(s, dev.clone()), - OutSlice::F8(s) => candle_core::CudaStorage::wrap_cuda_slice(s, dev.clone()), + OutSlice::BF16(s) => mcandle_core::CudaStorage::wrap_cuda_slice(s, dev.clone()), + OutSlice::F8(s) => mcandle_core::CudaStorage::wrap_cuda_slice(s, dev.clone()), }; Ok((out, out_shape)) @@ -288,7 +288,7 @@ pub fn fused_batch_matmul_f8( } } -impl candle_core::CustomOp2 for CublasLTBatchMatmulF8 { +impl mcandle_core::CustomOp2 for CublasLTBatchMatmulF8 { fn name(&self) -> &'static str { "cublaslt-batch-matmul-f8" } @@ -300,26 +300,26 @@ impl candle_core::CustomOp2 for CublasLTBatchMatmulF8 { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-batch-matmul-f8") + mcandle_core::bail!("no cpu support for cublaslt-batch-matmul-f8") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, None, None), + mcandle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, None, None), dt => { - candle_core::bail!("cublaslt-batch-matmul is only supported for f8e4m3 ({dt:?})") + mcandle_core::bail!("cublaslt-batch-matmul is only supported for f8e4m3 ({dt:?})") } } } } -impl candle_core::CustomOp3 for CublasLTBatchMatmulF8 { +impl mcandle_core::CustomOp3 for CublasLTBatchMatmulF8 { fn name(&self) -> &'static str { "cublaslt-batch-matmul-add-f8" } @@ -333,21 +333,23 @@ impl candle_core::CustomOp3 for CublasLTBatchMatmulF8 { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-batch-matmul-add-f8") + mcandle_core::bail!("no cpu support for cublaslt-batch-matmul-add-f8") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: &candle_core::CudaStorage, + bias: &mcandle_core::CudaStorage, bias_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F8E4M3 => self.fwd_f8e4m3(a, a_l, b, b_l, Some(bias), Some(bias_l)), - dt => candle_core::bail!( + mcandle_core::DType::F8E4M3 => { + self.fwd_f8e4m3(a, a_l, b, b_l, Some(bias), Some(bias_l)) + } + dt => mcandle_core::bail!( "cublaslt-batch-matmul-add is only supported for f8e4m3 ({dt:?})" ), } @@ -365,13 +367,13 @@ pub struct CublasLTBatchMatmul { impl CublasLTBatchMatmul { pub fn fwd_f16( &self, 
- a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -379,11 +381,11 @@ impl CublasLTBatchMatmul { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + mcandle_core::bail!("`b` must have the same batch size as `a`") } let lda = k; @@ -397,7 +399,7 @@ impl CublasLTBatchMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -409,22 +411,22 @@ impl CublasLTBatchMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -458,23 +460,23 @@ impl CublasLTBatchMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_bf16( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -482,11 +484,11 @@ impl CublasLTBatchMatmul { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + mcandle_core::bail!("`b` must have the same batch size as `a`") } let lda = k; @@ -500,7 +502,7 @@ impl CublasLTBatchMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? 
!= m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -512,22 +514,22 @@ impl CublasLTBatchMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? != (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -561,23 +563,23 @@ impl CublasLTBatchMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } pub fn fwd_f32( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: Option<&candle_core::CudaStorage>, + bias: Option<&mcandle_core::CudaStorage>, bias_l: Option<&Layout>, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { let dev = a.device(); // Assume TN @@ -585,11 +587,11 @@ impl CublasLTBatchMatmul { let (b_0, n, b_2) = b_l.shape().dims3()?; if b_2 != k { - candle_core::bail!("This layer only supports TN layout"); + mcandle_core::bail!("This layer only supports TN layout"); } if b_0 != batch_size { - candle_core::bail!("`b` must have the same batch size as `a`") + mcandle_core::bail!("`b` must have the same batch size as `a`") } let lda = k; @@ -603,7 +605,7 @@ impl CublasLTBatchMatmul { let bias = if let (Some(bias), Some(bias_l)) = (bias, bias_l) { if bias_l.shape().dims1()? != m { - candle_core::bail!("Bias does not have the correct shape"); + mcandle_core::bail!("Bias does not have the correct shape"); } Some(bias.as_cuda_slice::()?.slice(bias_l.start_offset()..)) @@ -615,22 +617,22 @@ impl CublasLTBatchMatmul { let (c, c_l) = c.storage_and_layout(); let c = match &*c { Storage::Cuda(storage) => storage.as_cuda_slice::()?, - _ => candle_core::bail!("`c` must be a cuda tensor"), + _ => mcandle_core::bail!("`c` must be a cuda tensor"), }; match c_l.contiguous_offsets() { Some((o1, o2)) => { if o1 != 0 { - candle_core::bail!("`c` start offset must be 0"); + mcandle_core::bail!("`c` start offset must be 0"); } if o2 != out_shape.elem_count() { - candle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) + mcandle_core::bail!("`c` end offset must be {}", out_shape.elem_count()) } } - None => candle_core::bail!("`c` has to be contiguous"), + None => mcandle_core::bail!("`c` has to be contiguous"), }; if c_l.shape().dims3()? 
!= (batch_size, n, m) { - candle_core::bail!("`c` does not have the correct shape"); + mcandle_core::bail!("`c` does not have the correct shape"); } // Set beta to 0.0 if it is not set @@ -664,16 +666,16 @@ impl CublasLTBatchMatmul { unsafe { self.cublaslt .matmul(config, &a, &b, &mut out, bias.as_ref(), self.act.as_ref()) - .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?; + .map_err(|e| mcandle_core::Error::Cuda(Box::new(e)))?; } - let out = candle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); + let out = mcandle_core::CudaStorage::wrap_cuda_slice(out, dev.clone()); Ok((out, out_shape)) } } -impl candle_core::CustomOp2 for CublasLTBatchMatmul { +impl mcandle_core::CustomOp2 for CublasLTBatchMatmul { fn name(&self) -> &'static str { "cublaslt-batch-matmul" } @@ -685,22 +687,22 @@ impl candle_core::CustomOp2 for CublasLTBatchMatmul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-batch-matmul") + mcandle_core::bail!("no cpu support for cublaslt-batch-matmul") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), - candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), - candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), + mcandle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, None, None), + mcandle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, None, None), + mcandle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, None, None), dt => { - candle_core::bail!( + mcandle_core::bail!( "cublaslt-batch-matmul is only supported for f16/bf16/f32 ({dt:?})" ) } @@ -708,7 +710,7 @@ impl candle_core::CustomOp2 for CublasLTBatchMatmul { } } -impl candle_core::CustomOp3 for CublasLTBatchMatmul { +impl mcandle_core::CustomOp3 for CublasLTBatchMatmul { fn name(&self) -> &'static str { "cublaslt-batch-matmul-add" } @@ -722,23 +724,23 @@ impl candle_core::CustomOp3 for CublasLTBatchMatmul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle_core::bail!("no cpu support for cublaslt-batch-matmul-add") + mcandle_core::bail!("no cpu support for cublaslt-batch-matmul-add") } fn cuda_fwd( &self, - a: &candle_core::CudaStorage, + a: &mcandle_core::CudaStorage, a_l: &Layout, - b: &candle_core::CudaStorage, + b: &mcandle_core::CudaStorage, b_l: &Layout, - bias: &candle_core::CudaStorage, + bias: &mcandle_core::CudaStorage, bias_l: &Layout, - ) -> Result<(candle_core::CudaStorage, Shape)> { + ) -> Result<(mcandle_core::CudaStorage, Shape)> { match a.dtype() { - candle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), - candle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), - dt => candle_core::bail!( + mcandle_core::DType::F16 => self.fwd_f16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + mcandle_core::DType::BF16 => self.fwd_bf16(a, a_l, b, b_l, Some(bias), Some(bias_l)), + mcandle_core::DType::F32 => self.fwd_f32(a, a_l, b, b_l, Some(bias), Some(bias_l)), + dt => mcandle_core::bail!( "cublaslt-batch-matmul-add is only supported for f16/bf16/f32 ({dt:?})" ), } @@ -791,7 +793,7 @@ mod tests { use std::f32::consts::PI; use super::*; - 
use candle_core::{DType, Device, IndexOp}; + use mcandle_core::{DType, Device, IndexOp}; // The bias bit seems to trip the test up. Not really sure why; it may be something locally. #[test] diff --git a/mistralrs-quant/src/cublaslt/matmul.rs b/mistralrs-quant/src/cublaslt/matmul.rs index 0b7ccd179..75618b869 100644 --- a/mistralrs-quant/src/cublaslt/matmul.rs +++ b/mistralrs-quant/src/cublaslt/matmul.rs @@ -1,13 +1,13 @@ -use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute; -use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys}; -use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream}; -use candle_core::cuda::cudarc::driver::{ - CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError, -}; use core::ffi::c_int; use core::mem; use float8::F8E4M3; use half::bf16; +use mcandle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute; +use mcandle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys}; +use mcandle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream}; +use mcandle_core::cuda::cudarc::driver::{ + CudaDevice, CudaSlice, DevicePtr, DevicePtrMut, DriverError, +}; use std::sync::Arc; /// Wrapper around [sys::cublasLtHandle_t] diff --git a/mistralrs-quant/src/cublaslt/mod.rs b/mistralrs-quant/src/cublaslt/mod.rs index beb72cef0..048446886 100644 --- a/mistralrs-quant/src/cublaslt/mod.rs +++ b/mistralrs-quant/src/cublaslt/mod.rs @@ -1,9 +1,9 @@ -// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/candle/src/layers/cublaslt.rs +// https://github.com/huggingface/text-embeddings-inference/blob/cc1c510e8d8af8447c01e6b14c417473cf2dfda9/backends/mcandle/src/layers/cublaslt.rs #![allow(unused_variables, unused_imports, dead_code)] -use candle_core::{Device, Result, Tensor}; -use candle_nn::Activation as CandleActivation; +use mcandle_core::{Device, Result, Tensor}; +use mcandle_nn::Activation as CandleActivation; use once_cell::sync::Lazy; use std::sync::{Mutex, Once}; @@ -38,7 +38,7 @@ pub fn maybe_init_cublas_lt_wrapper() { // Check if we can call the driver // Then check if we can create a device // Then check that the device is CUDA - use candle_core::cuda_backend::cudarc::driver; + use mcandle_core::cuda_backend::cudarc::driver; CUBLASLT = driver::result::init() .ok() .and_then(|_| Device::cuda_if_available(0).ok()) @@ -117,13 +117,13 @@ impl CublasLtWrapper { )?; if Some(CandleActivation::Swiglu) == act { - result = candle_nn::ops::swiglu(&result)?; + result = mcandle_nn::ops::swiglu(&result)?; } Ok(result) } #[cfg(not(feature = "cuda"))] { - candle_core::bail!("`cuda` feature is not enabled") + mcandle_core::bail!("`cuda` feature is not enabled") } } @@ -171,13 +171,13 @@ impl CublasLtWrapper { )?; if Some(CandleActivation::Swiglu) == act { - result = candle_nn::ops::swiglu(&result)?; + result = mcandle_nn::ops::swiglu(&result)?; } Ok(result) } #[cfg(not(feature = "cuda"))] { - candle_core::bail!("`cuda` feature is not enabled") + mcandle_core::bail!("`cuda` feature is not enabled") } } } diff --git a/mistralrs-quant/src/dummy/mod.rs b/mistralrs-quant/src/dummy/mod.rs index 377b7cd73..0628321a5 100644 --- a/mistralrs-quant/src/dummy/mod.rs +++ b/mistralrs-quant/src/dummy/mod.rs @@ -4,7 +4,7 @@ use crate::{QuantMethod, QuantizedSerde}; pub struct DummyLayer; impl QuantMethod for DummyLayer { - fn new(_method: crate::QuantMethodConfig) -> candle_core::Result + fn new(_method: crate::QuantMethodConfig) 
-> mcandle_core::Result where Self: Sized, { @@ -12,37 +12,37 @@ impl QuantMethod for DummyLayer { } fn add_delta_w( &self, - _delta: &candle_core::Tensor, - ) -> candle_core::Result<std::sync::Arc<dyn QuantMethod>> { - candle_core::bail!("DummyLayer should not ever be present in forward pass!") + _delta: &mcandle_core::Tensor, + ) -> mcandle_core::Result<std::sync::Arc<dyn QuantMethod>> { + mcandle_core::bail!("DummyLayer should not ever be present in forward pass!") } fn apply_isq( self: std::sync::Arc<Self>, _dtype: Option<crate::IsqType>, - _device: candle_core::Device, + _device: mcandle_core::Device, _n_quantized: &std::sync::atomic::AtomicUsize, - ) -> candle_core::Result<std::sync::Arc<dyn QuantMethod>> { - candle_core::bail!("DummyLayer should not ever be present in forward pass!") + ) -> mcandle_core::Result<std::sync::Arc<dyn QuantMethod>> { + mcandle_core::bail!("DummyLayer should not ever be present in forward pass!") } - fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) { - (candle_core::DType::F64, candle_core::Device::Cpu) + fn dtype_and_device(&self) -> (mcandle_core::DType, mcandle_core::Device) { + (mcandle_core::DType::F64, mcandle_core::Device::Cpu) } - fn forward(&self, _a: &candle_core::Tensor) -> candle_core::Result<candle_core::Tensor> { - candle_core::bail!("DummyLayer should not ever be present in forward pass!") + fn forward(&self, _a: &mcandle_core::Tensor) -> mcandle_core::Result<mcandle_core::Tensor> { + mcandle_core::bail!("DummyLayer should not ever be present in forward pass!") } fn forward_via_half( &self, - _a: &candle_core::Tensor, - ) -> candle_core::Result<candle_core::Tensor> { - candle_core::bail!("DummyLayer should not ever be present in forward pass!") + _a: &mcandle_core::Tensor, + ) -> mcandle_core::Result<mcandle_core::Tensor> { + mcandle_core::bail!("DummyLayer should not ever be present in forward pass!") } - fn get_bias_mut(&mut self) -> Option<&mut candle_core::Tensor> { + fn get_bias_mut(&mut self) -> Option<&mut mcandle_core::Tensor> { None } fn get_max_isq_cpu_threads(&self, _dtype: crate::IsqType) -> Option<NonZeroUsize> { None } - fn quantized_act_type(&self) -> Option<candle_core::DType> { + fn quantized_act_type(&self) -> Option<mcandle_core::DType> { None } } diff --git a/mistralrs-quant/src/fp8/mod.rs b/mistralrs-quant/src/fp8/mod.rs index 3de9c6f85..ea3e9e69f 100644 --- a/mistralrs-quant/src/fp8/mod.rs +++ b/mistralrs-quant/src/fp8/mod.rs @@ -6,8 +6,8 @@ use std::{ }; use byteorder::{LittleEndian, ReadBytesExt}; -use candle_core::{DType, Device, Result, Tensor, D}; -use candle_nn::{Linear, Module}; +use mcandle_core::{DType, Device, Result, Tensor, D}; +use mcandle_nn::{Linear, Module}; use quantize::QuantizationResult; mod quantize; @@ -32,7 +32,7 @@ pub struct FP8Linear { } impl QuantMethod for FP8Linear { - fn new(method: QuantMethodConfig) -> candle_core::Result<Self> + fn new(method: QuantMethodConfig) -> mcandle_core::Result<Self> where Self: Sized, { @@ -67,7 +67,7 @@ impl QuantMethod for FP8Linear { Some(handle) => { let n_dims = x.dims().len(); if n_dims < 3 { - candle_core::bail!( + mcandle_core::bail!( "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions" ); } @@ -138,7 +138,7 @@ impl QuantMethod for FP8Linear { })?)) } - fn dtype_and_device(&self) -> (DType, candle_core::Device) { + fn dtype_and_device(&self) -> (DType, mcandle_core::Device) { (DType::F8E4M3, self.lin.weight().device().clone()) } @@ -245,12 +245,12 @@ impl QuantizedSerde for FP8Linear { let version = buffer.read_u32::<LittleEndian>()?; if let Err(e) = version_is_compatible(version) { - return Err(candle_core::Error::wrap(e)); + return Err(mcandle_core::Error::wrap(e)); } let isq_type = buffer.read_u8()?
as usize; if isq_type != QuantizedSerdeType::Fp8 as usize { - candle_core::bail!( + mcandle_core::bail!( "ISQ type ({isq_type}) doesn't match expected type {}", QuantizedSerdeType::Fp8 as usize ); diff --git a/mistralrs-quant/src/fp8/quantize.rs b/mistralrs-quant/src/fp8/quantize.rs index c87dd82bf..be93fc917 100644 --- a/mistralrs-quant/src/fp8/quantize.rs +++ b/mistralrs-quant/src/fp8/quantize.rs @@ -1,6 +1,6 @@ -use candle_core::{DType, Result, Tensor}; -use candle_nn::Linear; use float8::F8E4M3; +use mcandle_core::{DType, Result, Tensor}; +use mcandle_nn::Linear; use super::FP8Linear; @@ -58,7 +58,7 @@ impl FP8Linear { #[cfg(test)] mod tests { - use candle_core::{ + use mcandle_core::{ quantized::{GgmlDType, QTensor}, DType, Device, Result, Tensor, }; diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs index 697b13567..cd3e60544 100644 --- a/mistralrs-quant/src/gguf/mod.rs +++ b/mistralrs-quant/src/gguf/mod.rs @@ -6,11 +6,11 @@ use std::{ }; use byteorder::{LittleEndian, ReadBytesExt}; -use candle_core::{ +use mcandle_core::{ quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor}, DType, Device, Result, Tensor, }; -use candle_nn::Module; +use mcandle_nn::Module; use crate::{ generate_isq, @@ -86,14 +86,14 @@ impl QuantMethod for GgufMatMul { } => { let (w, dtype) = (w.dequantize(&w.device())?, w.dtype()); let w = QMatMul::QTensor(std::sync::Arc::new( - candle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?, + mcandle_core::quantized::QTensor::quantize(&(w + delta)?, dtype)?, )); Ok(Arc::new(Self { w, b: b.clone() })) } } } - fn dtype_and_device(&self) -> (DType, candle_core::Device) { + fn dtype_and_device(&self) -> (DType, mcandle_core::Device) { match &self.w { QMatMul::QTensor(q) => (DType::F32, q.device()), QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()), @@ -229,7 +229,7 @@ impl QuantizedSerde for GgufMatMul { buffer } QMatMul::TensorF16(_) | QMatMul::Tensor(_) => { - candle_core::bail!("Cannot serialize non-quantized") + mcandle_core::bail!("Cannot serialize non-quantized") } }; @@ -245,12 +245,12 @@ impl QuantizedSerde for GgufMatMul { let version = buffer.read_u32::()?; if let Err(e) = version_is_compatible(version) { - return Err(candle_core::Error::wrap(e)); + return Err(mcandle_core::Error::wrap(e)); } let isq_type = buffer.read_u8()? as usize; if isq_type != QuantizedSerdeType::Gguf as usize { - candle_core::bail!( + mcandle_core::bail!( "ISQ type ({isq_type}) doesn't match expected type {}", QuantizedSerdeType::Gguf as usize ); @@ -278,7 +278,7 @@ impl QuantizedSerde for GgufMatMul { 15 => GgmlDType::Q8K, // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389 30 => GgmlDType::BF16, - _ => candle_core::bail!("unknown dtype for quantized weight tensor {dtype}"), + _ => mcandle_core::bail!("unknown dtype for quantized weight tensor {dtype}"), }; let n_dims = buffer.read_u32::()? 
as usize; diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs index ea66a243a..c22d7066b 100644 --- a/mistralrs-quant/src/gptq/gptq_cpu.rs +++ b/mistralrs-quant/src/gptq/gptq_cpu.rs @@ -1,6 +1,6 @@ use crate::{DummyLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizedConfig, QuantizedSerde}; -use candle_core::{DType, Device, Result, Tensor}; -use candle_nn::VarBuilder; +use mcandle_core::{DType, Device, Result, Tensor}; +use mcandle_nn::VarBuilder; use std::{ num::NonZeroUsize, sync::{atomic::AtomicUsize, Arc}, @@ -15,7 +15,9 @@ impl QuantMethod for GptqLayer { Self: Sized, { match method { - QuantMethodConfig::Gptq { .. } => candle_core::bail!("GPTQ is only supported on CUDA."), + QuantMethodConfig::Gptq { .. } => { + mcandle_core::bail!("GPTQ is only supported on CUDA.") + } QuantMethodConfig::Gguf { .. } | QuantMethodConfig::Unquantized(_) | QuantMethodConfig::Hqq { .. } @@ -38,7 +40,7 @@ impl QuantMethod for GptqLayer { todo!() } - fn dtype_and_device(&self) -> (DType, candle_core::Device) { + fn dtype_and_device(&self) -> (DType, mcandle_core::Device) { todo!() } diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs index c392c7394..026854de7 100644 --- a/mistralrs-quant/src/gptq/gptq_cuda.rs +++ b/mistralrs-quant/src/gptq/gptq_cuda.rs @@ -4,7 +4,9 @@ use std::{ sync::{atomic::AtomicUsize, Arc, Mutex}, }; -use candle_core::{ +use half::f16; +use lazy_static::lazy_static; +use mcandle_core::{ cuda::{ cudarc::{ cublas::{result::hgemm, sys::cublasOperation_t}, @@ -14,9 +16,7 @@ use candle_core::{ }, from_storage_no_op, Context, CudaStorage, DType, Device, Result, Shape, Storage, Tensor, D, }; -use candle_nn::VarBuilder; -use half::f16; -use lazy_static::lazy_static; +use mcandle_nn::VarBuilder; use crate::{ gptq::marlin_backend::{gptq_marlin_matmul, gptq_weight_repack}, @@ -65,7 +65,7 @@ impl GptqLayer { use_exllama: bool, ) -> Result { if !a.is_contiguous() { - candle_core::bail!( + mcandle_core::bail!( "Expected `a` to be contiguous, got strides {:?}", a.layout().stride() ) @@ -278,7 +278,7 @@ impl QuantMethod for GptqLayer { ); let reshaped_a = a.reshape(((), a.dim(D::Minus1)?))?; if !reshaped_a.device().is_cuda() { - candle_core::bail!("Expected CUDA input to GptqLayer"); + mcandle_core::bail!("Expected CUDA input to GptqLayer"); } let out = match ( @@ -317,7 +317,7 @@ impl QuantMethod for GptqLayer { } fn add_delta_w(&self, _delta: &Tensor) -> Result> { - candle_core::bail!("GPTQ quantization does not support adding weight delta.") + mcandle_core::bail!("GPTQ quantization does not support adding weight delta.") } fn dtype_and_device(&self) -> (DType, Device) { @@ -334,7 +334,7 @@ impl QuantMethod for GptqLayer { _device: Device, _n_quantized: &AtomicUsize, ) -> Result> { - candle_core::bail!("GPTQ quantization does not support ISQ.") + mcandle_core::bail!("GPTQ quantization does not support ISQ.") } fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option { diff --git a/mistralrs-quant/src/gptq/marlin_backend.rs b/mistralrs-quant/src/gptq/marlin_backend.rs index 200ad8a1a..a80ef5ab9 100644 --- a/mistralrs-quant/src/gptq/marlin_backend.rs +++ b/mistralrs-quant/src/gptq/marlin_backend.rs @@ -1,12 +1,12 @@ use super::marlin_ffi::{ gptq_marlin_repack, marlin_4bit_bf16, marlin_4bit_f16, HAVE_MARLIN_KERNELS, }; -use candle::backend::BackendStorage; -use candle::cuda_backend::cudarc::driver::DevicePtr; -use candle::cuda_backend::WrapErr; -use candle::{CpuStorage, CudaStorage, DType, Layout, Result, Shape, 
Storage, Tensor}; -use candle_core as candle; use half::{bf16, f16}; +use mcandle::backend::BackendStorage; +use mcandle::cuda_backend::cudarc::driver::DevicePtr; +use mcandle::cuda_backend::WrapErr; +use mcandle::{CpuStorage, CudaStorage, DType, Layout, Result, Shape, Storage, Tensor}; +use mcandle_core as mcandle; struct GPTQMatMul { workspace: Tensor, @@ -15,7 +15,7 @@ struct GPTQMatMul { impl GPTQMatMul { fn cuda_fwd_t< - T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + T: mcandle::cuda_backend::CudaDType + mcandle::cuda_backend::cudarc::driver::DeviceRepr, >( &self, x: &CudaStorage, @@ -61,7 +61,7 @@ impl GPTQMatMul { let (workspace, workspace_l) = self.workspace.storage_and_layout(); let workspace = match &*workspace { Storage::Cuda(p) => p, - _ => candle::bail!("workspace must be a cuda tensor"), + _ => mcandle::bail!("workspace must be a cuda tensor"), }; let workspace_ = workspace.as_cuda_slice::()?; let workspace_ = workspace_.slice(workspace_l.start_offset()..); @@ -74,7 +74,7 @@ impl GPTQMatMul { (size_k / scale_shape[0]) as i32 }; if !HAVE_MARLIN_KERNELS { - candle_core::bail!( + mcandle_core::bail!( "Marlin INT4xF16 matmul kernels were not compiled, please raise an issue." ) } @@ -113,7 +113,7 @@ impl GPTQMatMul { } } -impl candle::CustomOp3 for GPTQMatMul { +impl mcandle::CustomOp3 for GPTQMatMul { fn name(&self) -> &'static str { "GPTQMatMul" } @@ -127,7 +127,7 @@ impl candle::CustomOp3 for GPTQMatMul { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for GPTQMatMul") + mcandle::bail!("no cpu support for GPTQMatMul") } fn cuda_fwd( @@ -142,7 +142,7 @@ impl candle::CustomOp3 for GPTQMatMul { match x.dtype() { DType::F16 => self.cuda_fwd_t::(x, x_l, qweight, qweight_l, scale, scale_l), DType::BF16 => self.cuda_fwd_t::(x, x_l, qweight, qweight_l, scale, scale_l), - dt => candle::bail!("GPTQMatMul is only supported for f16 and bf16 ({dt:?})"), + dt => mcandle::bail!("GPTQMatMul is only supported for f16 and bf16 ({dt:?})"), } } } @@ -168,7 +168,7 @@ struct GPTQRepack { impl GPTQRepack { fn cuda_fwd_t< - T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + T: mcandle::cuda_backend::CudaDType + mcandle::cuda_backend::cudarc::driver::DeviceRepr, >( &self, qweight: &CudaStorage, @@ -204,7 +204,7 @@ impl GPTQRepack { gptq_marlin_repack(q_ptr, q_perm, out_ptr, self.k, q_shape[1] as i32, self.bits) } } else { - candle_core::bail!("Not compiled with marlin kernels, but attempted to use one. Please raise an issue."); + mcandle_core::bail!("Not compiled with marlin kernels, but attempted to use one. 
Please raise an issue."); } let out = CudaStorage::wrap_cuda_slice(out, dev.clone()); @@ -212,7 +212,7 @@ impl GPTQRepack { } } -impl candle::CustomOp2 for GPTQRepack { +impl mcandle::CustomOp2 for GPTQRepack { fn name(&self) -> &'static str { "GPTQRepack" } @@ -224,7 +224,7 @@ impl candle::CustomOp2 for GPTQRepack { _: &CpuStorage, _: &Layout, ) -> Result<(CpuStorage, Shape)> { - candle::bail!("no cpu support for GPTQRepack") + mcandle::bail!("no cpu support for GPTQRepack") } fn cuda_fwd( @@ -237,7 +237,7 @@ impl candle::CustomOp2 for GPTQRepack { match qweight.dtype() { DType::U32 => self.cuda_fwd_t::(qweight, qweight_l, perm, perm_l), DType::I32 => self.cuda_fwd_t::(qweight, qweight_l, perm, perm_l), - dt => candle::bail!("GPTQRepack is only supported for i32/u32 weight ({dt:?})"), + dt => mcandle::bail!("GPTQRepack is only supported for i32/u32 weight ({dt:?})"), } } } diff --git a/mistralrs-quant/src/hqq/hqq_op.rs b/mistralrs-quant/src/hqq/hqq_op.rs index 977b1ee3a..f612f9fd6 100644 --- a/mistralrs-quant/src/hqq/hqq_op.rs +++ b/mistralrs-quant/src/hqq/hqq_op.rs @@ -1,6 +1,6 @@ #[cfg(feature = "metal")] -use candle_core::{backend::BackendStorage, DType}; -use candle_core::{CpuStorage, CustomOp3, Layout, Result, Shape, WithDType}; +use mcandle_core::{backend::BackendStorage, DType}; +use mcandle_core::{CpuStorage, CustomOp3, Layout, Result, Shape, WithDType}; /* 8 bit @@ -35,10 +35,10 @@ impl CustomOp3 for Dequant8Bit { l_z: &Layout, ) -> Result<(CpuStorage, Shape)> { let CpuStorage::U8(w_slice) = w else { - candle_core::bail!("Weight must be u8, HQQ dequant 8-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 8-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } match (s, z) { (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( @@ -53,24 +53,24 @@ impl CustomOp3 for Dequant8Bit { CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), Shape::from_dims(&[self.h, self.w]), )), - (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + (_, _) => mcandle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), } } #[cfg(feature = "metal")] fn metal_fwd( &self, - w: &candle_core::MetalStorage, + w: &mcandle_core::MetalStorage, l_w: &Layout, - s: &candle_core::MetalStorage, + s: &mcandle_core::MetalStorage, l_s: &Layout, - z: &candle_core::MetalStorage, + z: &mcandle_core::MetalStorage, l_z: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { if w.dtype() != DType::U8 { - candle_core::bail!("Weight must be u8, HQQ dequant 8-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 8-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } let command_buffer = w.device().command_buffer()?; @@ -94,9 +94,9 @@ impl CustomOp3 for Dequant8Bit { self.w as u32, &output, ) - .map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), @@ -144,10 +144,10 @@ impl CustomOp3 for Dequant4Bit { const PACK_FACTOR: usize = 2; let CpuStorage::U8(w_slice) = w else { - candle_core::bail!("Weight must be u8, HQQ dequant 4-bit"); + 
mcandle_core::bail!("Weight must be u8, HQQ dequant 4-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } match (s, z) { (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( @@ -162,26 +162,26 @@ impl CustomOp3 for Dequant4Bit { CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), )), - (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + (_, _) => mcandle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), } } #[cfg(feature = "metal")] fn metal_fwd( &self, - w: &candle_core::MetalStorage, + w: &mcandle_core::MetalStorage, l_w: &Layout, - s: &candle_core::MetalStorage, + s: &mcandle_core::MetalStorage, l_s: &Layout, - z: &candle_core::MetalStorage, + z: &mcandle_core::MetalStorage, l_z: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { const PACK_FACTOR: usize = 2; if w.dtype() != DType::U8 { - candle_core::bail!("Weight must be u8, HQQ dequant 4-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 4-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } let command_buffer = w.device().command_buffer()?; @@ -205,9 +205,9 @@ impl CustomOp3 for Dequant4Bit { self.w as u32, &output, ) - .map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), @@ -257,10 +257,10 @@ impl CustomOp3 for Dequant2Bit { const PACK_FACTOR: usize = 4; let CpuStorage::U8(w_slice) = w else { - candle_core::bail!("Weight must be u8, HQQ dequant 2-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 2-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } match (s, z) { (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( @@ -275,26 +275,26 @@ impl CustomOp3 for Dequant2Bit { CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), )), - (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + (_, _) => mcandle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), } } #[cfg(feature = "metal")] fn metal_fwd( &self, - w: &candle_core::MetalStorage, + w: &mcandle_core::MetalStorage, l_w: &Layout, - s: &candle_core::MetalStorage, + s: &mcandle_core::MetalStorage, l_s: &Layout, - z: &candle_core::MetalStorage, + z: &mcandle_core::MetalStorage, l_z: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { const PACK_FACTOR: usize = 4; if w.dtype() != DType::U8 { - candle_core::bail!("Weight must be u8, HQQ dequant 2-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 2-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } let command_buffer = w.device().command_buffer()?; @@ -318,9 +318,9 @@ impl CustomOp3 for Dequant2Bit { self.w as u32, &output, ) - 
.map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), @@ -374,10 +374,10 @@ impl CustomOp3 for Dequant1Bit { const PACK_FACTOR: usize = 8; let CpuStorage::U8(w_slice) = w else { - candle_core::bail!("Weight must be u8, HQQ dequant 1-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 1-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } match (s, z) { (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( @@ -392,26 +392,26 @@ impl CustomOp3 for Dequant1Bit { CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), )), - (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + (_, _) => mcandle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), } } #[cfg(feature = "metal")] fn metal_fwd( &self, - w: &candle_core::MetalStorage, + w: &mcandle_core::MetalStorage, l_w: &Layout, - s: &candle_core::MetalStorage, + s: &mcandle_core::MetalStorage, l_s: &Layout, - z: &candle_core::MetalStorage, + z: &mcandle_core::MetalStorage, l_z: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { const PACK_FACTOR: usize = 8; if w.dtype() != DType::U8 { - candle_core::bail!("Weight must be u8, HQQ dequant 1-bit"); + mcandle_core::bail!("Weight must be u8, HQQ dequant 1-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } let command_buffer = w.device().command_buffer()?; @@ -435,9 +435,9 @@ impl CustomOp3 for Dequant1Bit { self.w as u32, &output, ) - .map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), @@ -493,10 +493,10 @@ impl CustomOp3 for Dequant3Bit { const PACK_FACTOR: usize = 10; let CpuStorage::I32(w_slice) = w else { - candle_core::bail!("Weight must be i32, HQQ dequant 3-bit"); + mcandle_core::bail!("Weight must be i32, HQQ dequant 3-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } match (s, z) { (CpuStorage::F32(s_slice), CpuStorage::F32(z_slice)) => Ok(( @@ -511,26 +511,26 @@ impl CustomOp3 for Dequant3Bit { CpuStorage::BF16(self.dequantize(w_slice, s_slice, z_slice)), Shape::from_dims(&[PACK_FACTOR * self.h, self.w]), )), - (_, _) => candle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), + (_, _) => mcandle_core::bail!("Dtype mismatch, expected one of f32, f16, bf16"), } } #[cfg(feature = "metal")] fn metal_fwd( &self, - w: &candle_core::MetalStorage, + w: &mcandle_core::MetalStorage, l_w: &Layout, - s: &candle_core::MetalStorage, + s: &mcandle_core::MetalStorage, l_s: &Layout, - z: &candle_core::MetalStorage, + z: &mcandle_core::MetalStorage, l_z: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { const PACK_FACTOR: usize = 10; if w.dtype() != DType::I32 { - 
candle_core::bail!("Weight must be i32, HQQ dequant 3-bit"); + mcandle_core::bail!("Weight must be i32, HQQ dequant 3-bit"); }; if !(l_w.is_contiguous() && l_s.is_contiguous() && l_z.is_contiguous()) { - candle_core::bail!("All inputs must be contiguous"); + mcandle_core::bail!("All inputs must be contiguous"); } let command_buffer = w.device().command_buffer()?; @@ -554,9 +554,9 @@ impl CustomOp3 for Dequant3Bit { self.w as u32, &output, ) - .map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs index c7b4a8fd0..6d12473e2 100644 --- a/mistralrs-quant/src/hqq/mod.rs +++ b/mistralrs-quant/src/hqq/mod.rs @@ -1,8 +1,8 @@ use byteorder::{LittleEndian, ReadBytesExt}; -use candle_core::{DType, Device, Result, Shape, Tensor}; +use mcandle_core::{DType, Device, Result, Shape, Tensor}; #[cfg(feature = "cuda")] -use candle_core::{ +use mcandle_core::{ cuda::{cudarc::driver::DevicePtr, CudaStorageSlice, WrapErr}, from_storage_no_op, CudaStorage, Storage, }; @@ -86,12 +86,12 @@ pub enum HqqAxis { } impl TryFrom for HqqAxis { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(value: usize) -> std::result::Result { match value { 0 => Ok(Self::Zero), 1 => Ok(Self::One), - other => candle_core::bail!("Unexpected value for HQQ axis {other}"), + other => mcandle_core::bail!("Unexpected value for HQQ axis {other}"), } } } @@ -106,7 +106,7 @@ pub enum HqqBits { } impl TryFrom for HqqBits { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(value: usize) -> std::result::Result { match value { 8 => Ok(Self::Eight), @@ -114,7 +114,7 @@ impl TryFrom for HqqBits { 3 => Ok(Self::Three), 2 => Ok(Self::Two), 1 => Ok(Self::One), - other => candle_core::bail!("Unexpected value for HQQ bits {other}"), + other => mcandle_core::bail!("Unexpected value for HQQ bits {other}"), } } } @@ -236,15 +236,15 @@ impl HqqLayer { match (self.scales.dtype(), self.zeros.dtype()) { (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (), (a, b) => { - candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).") + mcandle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).") } } if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous()) { - candle_core::bail!("All tensors must be contiguous!"); + mcandle_core::bail!("All tensors must be contiguous!"); } if self.cfg.axis as usize != 0 { - candle_core::bail!( + mcandle_core::bail!( "CPU HQQ dequantization requires axis == 0, got {}.", self.cfg.axis as usize ); @@ -272,7 +272,7 @@ impl HqqLayer { .w_q .apply_op3_no_bwd(&self.scales, &self.zeros, &Dequant1Bit { h, w })? 
.reshape(&self.w_shape), - b => candle_core::bail!("Unreachable bits {b}"), + b => mcandle_core::bail!("Unreachable bits {b}"), } } @@ -282,15 +282,15 @@ impl HqqLayer { match (self.scales.dtype(), self.zeros.dtype()) { (DType::F16, DType::F16) | (DType::BF16, DType::BF16) | (DType::F32, DType::F32) => (), (a, b) => { - candle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).") + mcandle_core::bail!("Expected all dtypes to be the same, got ({a:?}, {b:?}).") } } if !(self.w_q.is_contiguous() && self.scales.is_contiguous() && self.zeros.is_contiguous()) { - candle_core::bail!("All tensors must be contiguous!"); + mcandle_core::bail!("All tensors must be contiguous!"); } if self.cfg.axis as usize != 0 { - candle_core::bail!( + mcandle_core::bail!( "CUDA HQQ dequantization requires axis == 0, got {}.", self.cfg.axis as usize ); @@ -491,7 +491,9 @@ impl HqqLayer { 1bit_u8_kernel_bf16 ) } - (bits, dtype) => candle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}"), + (bits, dtype) => { + mcandle_core::bail!("Unsupported bit width {bits} and dtype {dtype:?}") + } }; inner.reshape(&self.w_shape) } @@ -574,7 +576,7 @@ impl QuantMethod for HqqLayer { } fn add_delta_w(&self, _delta: &Tensor) -> Result> { - candle_core::bail!("HQQ quantization does not support adding weight delta.") + mcandle_core::bail!("HQQ quantization does not support adding weight delta.") } fn dtype_and_device(&self) -> (DType, Device) { @@ -598,7 +600,7 @@ impl QuantMethod for HqqLayer { // Some(IsqType::HQQ3) => HqqBits::Three, // Some(IsqType::HQQ2) => HqqBits::Two, // Some(IsqType::HQQ1) => HqqBits::One, - _ => candle_core::bail!("Expected a HQQ ISQ type."), + _ => mcandle_core::bail!("Expected a HQQ ISQ type."), }; let cfg = HqqConfig { bits, @@ -717,12 +719,12 @@ impl QuantizedSerde for HqqLayer { let version = buffer.read_u32::()?; if let Err(e) = version_is_compatible(version) { - return Err(candle_core::Error::wrap(e)); + return Err(mcandle_core::Error::wrap(e)); } let isq_type = buffer.read_u8()? 
as usize; if isq_type != QuantizedSerdeType::Hqq as usize { - candle_core::bail!( + mcandle_core::bail!( "ISQ type ({isq_type}) doesn't match expected type {}", QuantizedSerdeType::Hqq as usize ); diff --git a/mistralrs-quant/src/hqq/optimize.rs b/mistralrs-quant/src/hqq/optimize.rs index 9168f5856..4c91ab562 100644 --- a/mistralrs-quant/src/hqq/optimize.rs +++ b/mistralrs-quant/src/hqq/optimize.rs @@ -1,4 +1,4 @@ -use candle_core::{DType, Result, Tensor}; +use mcandle_core::{DType, Result, Tensor}; use super::{HqqAxis, HqqLayer, OPTIMIZER_HQQ_DEFAULT_STEPS}; diff --git a/mistralrs-quant/src/hqq/quantize.rs b/mistralrs-quant/src/hqq/quantize.rs index 17707d7d2..31563f6d6 100644 --- a/mistralrs-quant/src/hqq/quantize.rs +++ b/mistralrs-quant/src/hqq/quantize.rs @@ -1,4 +1,4 @@ -use candle_core::{DType, Device, Result, Tensor}; +use mcandle_core::{DType, Device, Result, Tensor}; use crate::hqq::optimize::OptResults; @@ -9,7 +9,7 @@ impl HqqLayer { pub fn quantize(input: &Tensor, device: &Device, cfg: HqqConfig) -> Result { let group_size: usize = cfg.group_size.into(); if input.elem_count() % group_size != 0 { - candle_core::bail!("`group_size` should be divisible by the tensor number of elements, which are {}, got a group size of {group_size}.", input.elem_count()); + mcandle_core::bail!("`group_size` should be divisible by the tensor number of elements, which are {}, got a group size of {group_size}.", input.elem_count()); } let mut w = input.clone().to_dtype(DType::F32)?; @@ -99,11 +99,11 @@ impl HqqLayer { #[cfg(test)] mod test { - use candle_core::{Device, Result, Tensor}; + use mcandle_core::{Device, Result, Tensor}; #[test] fn test_quantize_hqq() -> Result<()> { - use candle_core::DType; + use mcandle_core::DType; use crate::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs index 172ef9740..ca28c1227 100644 --- a/mistralrs-quant/src/lib.rs +++ b/mistralrs-quant/src/lib.rs @@ -5,7 +5,7 @@ use std::{ sync::{atomic::AtomicUsize, Arc}, }; -use candle_core::{ +use mcandle_core::{ quantized::{GgmlDType, QTensor}, DType, Device, Result, Tensor, }; @@ -30,7 +30,7 @@ pub use gptq::GptqLayer; pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer}; pub use unquantized::UnquantLinear; -use candle_nn::{Linear, Module, VarBuilder}; +use mcandle_nn::{Linear, Module, VarBuilder}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Deserialize, Serialize, Default)] @@ -114,7 +114,7 @@ pub enum IsqType { } impl TryFrom for GgmlDType { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(value: IsqType) -> Result { let tp = match value { @@ -130,7 +130,7 @@ impl TryFrom for GgmlDType { IsqType::Q8K => Self::Q8K, IsqType::Q8_0 => Self::Q8_0, IsqType::Q8_1 => Self::Q8_1, - _ => candle_core::bail!("Expected valid GGML ISQ type."), + _ => mcandle_core::bail!("Expected valid GGML ISQ type."), }; #[cfg(feature = "cuda")] { @@ -147,7 +147,7 @@ impl TryFrom for GgmlDType { | GgmlDType::Q5K | GgmlDType::Q6K ) { - candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`") + mcandle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`") } } Ok(tp) @@ -162,14 +162,14 @@ pub enum QuantizedSerdeType { } impl TryFrom for QuantizedSerdeType { - type Error = candle_core::Error; + type Error = mcandle_core::Error; fn try_from(value: usize) -> 
std::result::Result { match value { 0 => Ok(Self::Gguf), 1 => Ok(Self::Unquant), 2 => Ok(Self::Hqq), 3 => Ok(Self::Fp8), - other => candle_core::bail!("QuantizedSerdeType {other} is invalid."), + other => mcandle_core::bail!("QuantizedSerdeType {other} is invalid."), } } } @@ -180,13 +180,13 @@ pub trait QuantizedSerde { false } fn serialize(&self) -> Result> { - candle_core::bail!("`QuantizedSerde::serialize` is not supported.") + mcandle_core::bail!("`QuantizedSerde::serialize` is not supported.") } fn deserialize(_data: Cow<[u8]>, _device: &Device) -> Result> where Self: Sized, { - candle_core::bail!("`QuantizedSerde::deserialize` is not supported.") + mcandle_core::bail!("`QuantizedSerde::deserialize` is not supported.") } } @@ -254,7 +254,7 @@ pub fn linear_no_bias( let layer = ::new(QuantMethodConfig::Dummy)?; Arc::new(layer) as Arc } else { - let layer = candle_nn::linear_no_bias(in_dim, out_dim, vb)?; + let layer = mcandle_nn::linear_no_bias(in_dim, out_dim, vb)?; let layer = ::new(QuantMethodConfig::Unquantized(layer))?; Arc::new(layer) as Arc @@ -279,7 +279,7 @@ pub fn linear( let layer = ::new(QuantMethodConfig::Dummy)?; Arc::new(layer) as Arc } else { - let layer = candle_nn::linear(in_dim, out_dim, vb)?; + let layer = mcandle_nn::linear(in_dim, out_dim, vb)?; let layer = ::new(QuantMethodConfig::Unquantized(layer))?; Arc::new(layer) as Arc diff --git a/mistralrs-quant/src/metal_kernels/mod.rs b/mistralrs-quant/src/metal_kernels/mod.rs index 7fd6f7425..3e6e63535 100644 --- a/mistralrs-quant/src/metal_kernels/mod.rs +++ b/mistralrs-quant/src/metal_kernels/mod.rs @@ -1,4 +1,4 @@ -use candle_core::DType; +use mcandle_core::DType; use metal::{ Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs index cb921073d..4c8d856c9 100644 --- a/mistralrs-quant/src/unquantized/mod.rs +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -6,8 +6,8 @@ use std::{ }; use byteorder::{LittleEndian, ReadBytesExt}; -use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D}; -use candle_nn::Linear; +use mcandle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D}; +use mcandle_nn::Linear; use crate::{ cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_HANDLE}, @@ -25,7 +25,7 @@ pub struct UnquantLinear { } impl QuantMethod for UnquantLinear { - fn new(method: QuantMethodConfig) -> candle_core::Result + fn new(method: QuantMethodConfig) -> mcandle_core::Result where Self: Sized, { @@ -109,7 +109,7 @@ impl QuantMethod for UnquantLinear { })) } - fn dtype_and_device(&self) -> (DType, candle_core::Device) { + fn dtype_and_device(&self) -> (DType, mcandle_core::Device) { (self.w.dtype(), self.w.device().clone()) } @@ -282,12 +282,12 @@ impl QuantizedSerde for UnquantLinear { let version = buffer.read_u32::()?; if let Err(e) = version_is_compatible(version) { - return Err(candle_core::Error::wrap(e)); + return Err(mcandle_core::Error::wrap(e)); } let isq_type = buffer.read_u8()? 
as usize; if isq_type != QuantizedSerdeType::Unquant as usize { - candle_core::bail!( + mcandle_core::bail!( "ISQ type ({isq_type}) doesn't match expected type {}", QuantizedSerdeType::Unquant as usize ); diff --git a/mistralrs-quant/src/utils/isq.rs b/mistralrs-quant/src/utils/isq.rs index f072aa6f6..c02b23a70 100644 --- a/mistralrs-quant/src/utils/isq.rs +++ b/mistralrs-quant/src/utils/isq.rs @@ -1,4 +1,4 @@ -use candle_core::{quantized::GgmlDType, Tensor}; +use mcandle_core::{quantized::GgmlDType, Tensor}; pub enum QuantizationBehaviour { Quantize(GgmlDType), @@ -57,11 +57,11 @@ macro_rules! generate_isq { $crate::utils::isq::QuantizationBehaviour::Skip => { let shape = $tensor.shape(); tracing::warn!("Skipping quantization of tensor with shape {shape:?} as it is not quantizable."); - Arc::new(candle_core::quantized::QTensor::quantize_onto(&$tensor, GgmlDType::F32, &$device)?) + Arc::new(mcandle_core::quantized::QTensor::quantize_onto(&$tensor, GgmlDType::F32, &$device)?) }, $crate::utils::isq::QuantizationBehaviour::Quantize(dtype) => { $n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - Arc::new(candle_core::quantized::QTensor::quantize_onto(&$tensor, dtype, &$device)?) + Arc::new(mcandle_core::quantized::QTensor::quantize_onto(&$tensor, dtype, &$device)?) } } } diff --git a/mistralrs-quant/src/utils/mod.rs b/mistralrs-quant/src/utils/mod.rs index 51650c7f9..98b6c39dc 100644 --- a/mistralrs-quant/src/utils/mod.rs +++ b/mistralrs-quant/src/utils/mod.rs @@ -12,7 +12,7 @@ pub(crate) use uqff::{ }; #[cfg(feature = "cuda")] -use candle_core::{ +use mcandle_core::{ cuda::{cudarc::driver::DevicePtr, CudaDType}, CudaDevice, Device, Storage, Tensor, WithDType, }; @@ -20,20 +20,20 @@ use candle_core::{ #[cfg(feature = "cuda")] pub(crate) fn get_cuda_slice( x: &Tensor, -) -> candle_core::Result<*const T> { +) -> mcandle_core::Result<*const T> { let offset = x.layout().start_offset(); match &*x.storage_and_layout().0 { Storage::Cuda(a_storage) => { Ok(*a_storage.as_cuda_slice::()?.slice(offset..).device_ptr() as *const T) } - _ => candle_core::bail!("Expected CUDA storage."), + _ => mcandle_core::bail!("Expected CUDA storage."), } } #[cfg(feature = "cuda")] -pub(crate) fn get_cuda_device(x: &Tensor) -> candle_core::Result<&CudaDevice> { +pub(crate) fn get_cuda_device(x: &Tensor) -> mcandle_core::Result<&CudaDevice> { match x.device() { Device::Cuda(dev) => Ok(dev), - _ => candle_core::bail!("Expected CUDA device"), + _ => mcandle_core::bail!("Expected CUDA device"), } } diff --git a/mistralrs-quant/src/utils/ops.rs b/mistralrs-quant/src/utils/ops.rs index c7f80db29..becf36238 100644 --- a/mistralrs-quant/src/utils/ops.rs +++ b/mistralrs-quant/src/utils/ops.rs @@ -1,4 +1,4 @@ -use candle_core::{ +use mcandle_core::{ backend::BackendStorage, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout, Result, Shape, Tensor, WithDType, }; @@ -9,7 +9,7 @@ use std::ops::{BitOr, Shl}; #[cfg(feature = "cuda")] use crate::utils::ffi; #[cfg(feature = "cuda")] -use candle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr}; +use mcandle_core::cuda::{cudarc::driver::DevicePtr, CudaStorage, WrapErr}; #[cfg(feature = "cuda")] use std::ffi::c_void; @@ -54,15 +54,15 @@ impl CustomOp2 for BitWiseOr { CpuStorage::U8(vs1) => { let vs1 = match l1.contiguous_offsets() { Some((start, end)) => &vs1[start..end], - None => candle_core::bail!("Input tensor s1 must be contiguous"), + None => mcandle_core::bail!("Input tensor s1 must be contiguous"), }; let vs2 = s2.as_slice::()?; let vs2 = match 
l2.contiguous_offsets() { Some((start, end)) => &vs2[start..end], - None => candle_core::bail!("Input tensor s2 must be contiguous"), + None => mcandle_core::bail!("Input tensor s2 must be contiguous"), }; if vs1.len() != vs2.len() { - candle_core::bail!("Input tensors must have the same number of elements"); + mcandle_core::bail!("Input tensors must have the same number of elements"); }; let result = self.bitwise(vs1, vs2); let result = CpuStorage::U8(result); @@ -191,11 +191,11 @@ impl CustomOp2 for BitWiseOr { #[cfg(feature = "metal")] fn metal_fwd( &self, - s1: &candle_core::MetalStorage, + s1: &mcandle_core::MetalStorage, l1: &Layout, - s2: &candle_core::MetalStorage, + s2: &mcandle_core::MetalStorage, l2: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { if l1.shape() != l2.shape() || l1.stride() != l2.stride() { return Err(Error::ShapeMismatchBinaryOp { lhs: l1.shape().clone(), @@ -211,10 +211,10 @@ impl CustomOp2 for BitWiseOr { }); } if !l1.is_contiguous() { - candle_core::bail!("Input tensor s1 must be contiguous"); + mcandle_core::bail!("Input tensor s1 must be contiguous"); } if !l2.is_contiguous() { - candle_core::bail!("Input tensor s2 must be contiguous"); + mcandle_core::bail!("Input tensor s2 must be contiguous"); } let command_buffer = s1.device().command_buffer()?; @@ -236,9 +236,9 @@ impl CustomOp2 for BitWiseOr { out_shape.elem_count(), &output, ) - .map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), @@ -274,7 +274,7 @@ impl CustomOp1 for Leftshift { fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> { if !l1.is_contiguous() { - candle_core::bail!("Input tensor s1 must be contiguous"); + mcandle_core::bail!("Input tensor s1 must be contiguous"); } match s1 { CpuStorage::U8(vs1) => { @@ -300,7 +300,7 @@ impl CustomOp1 for Leftshift { #[cfg(feature = "cuda")] fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> { if !l1.is_contiguous() { - candle_core::bail!("Input tensor s1 must be contiguous"); + mcandle_core::bail!("Input tensor s1 must be contiguous"); } let dev = s1.device().clone(); let (d_in1_ptr, elem_count) = match s1.dtype() { @@ -379,11 +379,11 @@ impl CustomOp1 for Leftshift { #[cfg(feature = "metal")] fn metal_fwd( &self, - s1: &candle_core::MetalStorage, + s1: &mcandle_core::MetalStorage, l1: &Layout, - ) -> Result<(candle_core::MetalStorage, Shape)> { + ) -> Result<(mcandle_core::MetalStorage, Shape)> { if !l1.is_contiguous() { - candle_core::bail!("Input tensor s1 must be contiguous"); + mcandle_core::bail!("Input tensor s1 must be contiguous"); } let command_buffer = s1.device().command_buffer()?; @@ -405,9 +405,9 @@ impl CustomOp1 for Leftshift { out_shape.elem_count(), &output, ) - .map_err(candle_core::Error::wrap)?; + .map_err(mcandle_core::Error::wrap)?; - let newstorage = candle_core::MetalStorage::new( + let newstorage = mcandle_core::MetalStorage::new( output, device.clone(), out_shape.elem_count(), @@ -432,8 +432,8 @@ mod tests { #[test] fn test_bitwise_or_cpu() { use crate::utils::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; let a = Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = 
Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap(); @@ -445,8 +445,8 @@ mod tests { #[test] fn test_bitwise_or_cuda() { use crate::utils::ops::BitWiseOp; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec(vec![1i32, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap(); let b = Tensor::from_vec(vec![-1i32, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap(); @@ -457,8 +457,8 @@ mod tests { #[test] fn test_leftshift_cpu() { use crate::utils::ops::LeftshiftOp; - use candle_core::Tensor; - let device = candle_core::Device::Cpu; + use mcandle_core::Tensor; + let device = mcandle_core::Device::Cpu; let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap(); let c = a.leftshift(2).unwrap().to_vec2::().unwrap(); assert_eq!(c, [[4, 8], [12, 16], [20, 24]]); @@ -468,8 +468,8 @@ mod tests { #[test] fn test_leftshift_cuda() { use crate::utils::ops::LeftshiftOp; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec(vec![1i32, 2, 3, 4, 5, 6], (3, 2), &device).unwrap(); let c = a.leftshift(2).unwrap().to_vec2::().unwrap(); assert_eq!(c, [[4, 8], [12, 16], [20, 24]]); @@ -479,8 +479,8 @@ mod tests { #[test] fn test_bitwise_or_and_leftshift_cuda() { use crate::utils::{ops::BitWiseOp, LeftshiftOp}; - use candle_core::Tensor; - let device = candle_core::Device::new_cuda(0).unwrap(); + use mcandle_core::Tensor; + let device = mcandle_core::Device::new_cuda(0).unwrap(); let a = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap(); let b = Tensor::from_vec(vec![0b00001111u8], (1,), &device).unwrap(); let c = a @@ -501,7 +501,7 @@ mod tests { #[test] fn test_bitpack_8bit() { use crate::HqqBits; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let bits = HqqBits::Eight; let device = Device::Cpu; let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap(); @@ -516,8 +516,8 @@ mod tests { #[test] fn test_bitpack_8bit() { use crate::HqqBits; - use candle_core::DType; - use candle_core::{Device, Tensor}; + use mcandle_core::DType; + use mcandle_core::{Device, Tensor}; let bits = HqqBits::Eight; let device = Device::new_cuda(0).unwrap(); let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap(); @@ -534,7 +534,7 @@ mod tests { #[test] fn test_bitpack_4bit() { use crate::HqqBits; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let bits = HqqBits::Four; let device = Device::Cpu; let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap(); @@ -549,7 +549,7 @@ mod tests { #[test] fn test_bitpack_4bit() { use crate::HqqBits; - use candle_core::{Device, Tensor}; + use mcandle_core::{Device, Tensor}; let bits = HqqBits::Four; let device = Device::new_cuda(0).unwrap(); let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap(); diff --git a/mistralrs-quant/src/utils/uqff.rs b/mistralrs-quant/src/utils/uqff.rs index bf5b0e7ae..a9f6ece51 100644 --- a/mistralrs-quant/src/utils/uqff.rs +++ b/mistralrs-quant/src/utils/uqff.rs @@ -1,8 +1,8 @@ use byteorder::{LittleEndian, ReadBytesExt}; -use candle_core::{DType, Device, Result, Tensor, WithDType}; use float8::F8E4M3; use half::{bf16, f16}; +use mcandle_core::{DType, Device, 
Result, Tensor, WithDType}; // v0.1.0: initial release // v0.1.1: add i16 dtype @@ -23,7 +23,7 @@ pub(crate) fn version_is_compatible(version: u32) -> Result<()> { let _patch = version; if major != HQFF_VERSION_MAJOR { - candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({HQFF_VERSION_MAJOR})"); + mcandle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({HQFF_VERSION_MAJOR})"); } Ok(()) @@ -61,7 +61,7 @@ pub(crate) fn read_dtype(buffer: &mut R) -> Result { 7 => DType::F64, 8 => DType::I16, 9 => DType::F8E4M3, - _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"), + _ => mcandle_core::bail!("unknown dtype for quantized tensor {dtype}"), }; Ok(dtype) } @@ -164,7 +164,7 @@ fn data_to_bytes(mut vs: Vec) -> Vec { fn bytes_to_data( data: &[u8], shape: &[usize], - device: &candle_core::Device, + device: &mcandle_core::Device, ) -> Result { let size_in_bytes = T::DTYPE.size_in_bytes(); let elem_count = data.len() / size_in_bytes; diff --git a/mistralrs-server/Cargo.toml b/mistralrs-server/Cargo.toml index cb134f7c9..d27f92610 100644 --- a/mistralrs-server/Cargo.toml +++ b/mistralrs-server/Cargo.toml @@ -15,7 +15,7 @@ default-run = "mistralrs-server" [dependencies] anyhow.workspace = true ctrlc = "3.4.4" -candle-core.workspace = true +mcandle-core.workspace = true serde.workspace = true serde_json.workspace = true axum = { version = "0.7.4", features = ["tokio"] } diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 4db596e60..c429d8b82 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -5,8 +5,8 @@ use axum::{ routing::{get, post}, Router, }; -use candle_core::Device; use clap::Parser; +use mcandle_core::Device; use mistralrs_core::{ get_model_dtype, get_tgt_non_granular_index, initialize_logging, paged_attn_supported, parse_isq_value, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, IsqType, @@ -311,10 +311,10 @@ async fn main() -> Result<()> { info!( "avx: {}, neon: {}, simd128: {}, f16c: {}", - candle_core::utils::with_avx(), - candle_core::utils::with_neon(), - candle_core::utils::with_simd128(), - candle_core::utils::with_f16c() + mcandle_core::utils::with_avx(), + mcandle_core::utils::with_neon(), + mcandle_core::utils::with_simd128(), + mcandle_core::utils::with_f16c() ); info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial"); if use_flash_attn { diff --git a/mistralrs-vision/Cargo.toml b/mistralrs-vision/Cargo.toml index 408312578..68acfec89 100644 --- a/mistralrs-vision/Cargo.toml +++ b/mistralrs-vision/Cargo.toml @@ -12,5 +12,5 @@ license.workspace = true homepage.workspace = true [dependencies] -candle-core.workspace = true +mcandle-core.workspace = true image.workspace = true diff --git a/mistralrs-vision/src/lib.rs b/mistralrs-vision/src/lib.rs index d1a6512da..43092e155 100644 --- a/mistralrs-vision/src/lib.rs +++ b/mistralrs-vision/src/lib.rs @@ -4,7 +4,7 @@ //! //! ## Example //! ```rust -//! use candle_core::Device; +//! use mcandle_core::Device; //! use image::{ColorType, DynamicImage}; //! use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms}; //! @@ -20,8 +20,8 @@ //! assert_eq!(transformed.dims(), &[3, 4, 3]); //! 
``` -use candle_core::{Device, Result, Tensor}; use image::DynamicImage; +use mcandle_core::{Device, Result, Tensor}; mod ops; mod pad; mod transforms; diff --git a/mistralrs-vision/src/ops.rs b/mistralrs-vision/src/ops.rs index 610251e90..2076a0d3b 100644 --- a/mistralrs-vision/src/ops.rs +++ b/mistralrs-vision/src/ops.rs @@ -1,4 +1,4 @@ -use candle_core::{Result, Tensor}; +use mcandle_core::{Result, Tensor}; /// Pad an image of shape (c, h, w) to (c, max_h, max_w) by padding with zeros on the right and bottom. pub fn pad(image: &Tensor, max_h: usize, max_w: usize) -> Result { @@ -15,7 +15,7 @@ pub fn make_pixel_mask(image: &Tensor, h: usize, w: usize) -> Result { let (_c, max_h, max_w) = image.dims3()?; let mask = Tensor::ones((h, w), image.dtype(), image.device())?; let zeros = Tensor::zeros((max_h, max_w), image.dtype(), image.device())?; - // TODO(EricLBuehler): https://github.com/huggingface/candle/pull/2223 will make this nicer + // TODO(EricLBuehler): https://github.com/huggingface/mcandle/pull/2223 will make this nicer zeros.slice_assign(&[&(..h), &(..w)], &mask) } diff --git a/mistralrs-vision/src/transforms.rs b/mistralrs-vision/src/transforms.rs index 71fecf70e..f3e6eb152 100644 --- a/mistralrs-vision/src/transforms.rs +++ b/mistralrs-vision/src/transforms.rs @@ -1,6 +1,6 @@ use crate::utils::{get_pixel_data, n_channels}; -use candle_core::{DType, Device, Result, Tensor}; use image::{DynamicImage, GenericImageView}; +use mcandle_core::{DType, Device, Result, Tensor}; use crate::ImageTransform; @@ -85,7 +85,7 @@ impl ImageTransform for Normalize { fn map(&self, x: &Self::Input, _: &Device) -> Result { let num_channels = x.dim(0)?; if self.mean.len() != num_channels || self.std.len() != num_channels { - candle_core::bail!("Num channels must match number of mean and std."); + mcandle_core::bail!("Num channels must match number of mean and std."); } let mut accum = Vec::new(); for (i, channel) in x.chunk(num_channels, 0)?.iter().enumerate() { @@ -148,8 +148,8 @@ impl ImageTransform for Rescale { mod tests { #[test] fn test_to_tensor() { - use candle_core::Device; use image::{ColorType, DynamicImage}; + use mcandle_core::Device; use crate::ImageTransform; @@ -163,7 +163,7 @@ mod tests { #[test] fn test_normalize() { use crate::{ImageTransform, Normalize}; - use candle_core::{DType, Device, Tensor}; + use mcandle_core::{DType, Device, Tensor}; let image = Tensor::zeros((3, 5, 4), DType::U8, &Device::Cpu).unwrap(); let res = Normalize { diff --git a/mistralrs-vision/tests/integration.rs b/mistralrs-vision/tests/integration.rs index 10a941b29..a52728605 100644 --- a/mistralrs-vision/tests/integration.rs +++ b/mistralrs-vision/tests/integration.rs @@ -1,5 +1,5 @@ -use candle_core::Device; use image::{ColorType, DynamicImage}; +use mcandle_core::Device; use mistralrs_vision::{ApplyTransforms, InterpolateResize, Normalize, ToTensor, Transforms}; #[test] diff --git a/mistralrs/Cargo.toml b/mistralrs/Cargo.toml index 49816a51f..eab9da43a 100644 --- a/mistralrs/Cargo.toml +++ b/mistralrs/Cargo.toml @@ -15,7 +15,7 @@ homepage.workspace = true mistralrs-core = { version = "0.3.2", path = "../mistralrs-core" } anyhow.workspace = true tokio.workspace = true -candle-core.workspace = true +mcandle-core.workspace = true serde_json.workspace = true serde.workspace = true image.workspace = true diff --git a/mistralrs/src/lib.rs b/mistralrs/src/lib.rs index c91240662..91feae838 100644 --- a/mistralrs/src/lib.rs +++ b/mistralrs/src/lib.rs @@ -81,5 +81,5 @@ pub mod v0_4_api { pub use v0_4_api::*; -pub 
use candle_core::{DType, Device, Result, Tensor};
+pub use mcandle_core::{DType, Device, Result, Tensor};
 pub use mistralrs_core::*;
diff --git a/mistralrs/src/model.rs b/mistralrs/src/model.rs
index d35267401..27bb6f645 100644
--- a/mistralrs/src/model.rs
+++ b/mistralrs/src/model.rs
@@ -1,5 +1,5 @@
 use anyhow::Context;
-use candle_core::{Device, Result};
+use mcandle_core::{Device, Result};
 use mistralrs_core::*;
 use std::sync::Arc;
 use tokio::sync::mpsc::channel;
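// Illustrative sketch, not part of the patch: the Dequant4Bit / Dequant2Bit / Dequant1Bit ops
// above rely on an HQQ packing layout in which PACK_FACTOR quantized values share one storage
// word (2 x 4-bit, 4 x 2-bit, or 8 x 1-bit per u8; the 3-bit path packs 10 values per i32).
// The helpers below re-derive that bookkeeping for u8 words only. `pack_u8` / `unpack_u8` are
// hypothetical names, and the MSB-first field order is an assumption made for illustration;
// the patch does not show the bit order used by the actual CPU/CUDA/Metal kernels.

/// Pack `values` (each < 2^bits) into bytes, `8 / bits` values per byte,
/// most-significant field first.
fn pack_u8(values: &[u8], bits: usize) -> Vec<u8> {
    assert!(matches!(bits, 1 | 2 | 4), "only sub-byte widths that divide 8");
    let per_byte = 8 / bits; // PACK_FACTOR for u8 storage: 2, 4, or 8
    values
        .chunks(per_byte)
        .map(|chunk| {
            chunk.iter().enumerate().fold(0u8, |acc, (i, &v)| {
                debug_assert!((v as u32) < (1u32 << bits));
                acc | (v << (8 - bits * (i + 1)))
            })
        })
        .collect()
}

/// Inverse of `pack_u8`: expand each byte back into `8 / bits` values.
fn unpack_u8(packed: &[u8], bits: usize) -> Vec<u8> {
    assert!(matches!(bits, 1 | 2 | 4), "only sub-byte widths that divide 8");
    let per_byte = 8 / bits;
    let mask = ((1u16 << bits) - 1) as u8;
    packed
        .iter()
        .flat_map(|&byte| (0..per_byte).map(move |i| (byte >> (8 - bits * (i + 1))) & mask))
        .collect()
}

fn main() {
    // Two 4-bit values per byte, matching PACK_FACTOR = 2 in Dequant4Bit.
    let vals = [0x1u8, 0xF, 0x7, 0x3];
    let packed = pack_u8(&vals, 4);
    assert_eq!(packed, vec![0x1Fu8, 0x73]);
    assert_eq!(unpack_u8(&packed, 4), vals.to_vec());

    // Eight 1-bit values per byte, matching PACK_FACTOR = 8 in Dequant1Bit.
    let ones = [1u8, 0, 1, 1, 0, 0, 1, 0];
    assert_eq!(pack_u8(&ones, 1), vec![0b1011_0010u8]);
    assert_eq!(unpack_u8(&pack_u8(&ones, 1), 1), ones.to_vec());
}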